mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; details were linked from the original registry page.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -14,29 +14,31 @@
14
14
  # ============================================================================
15
15
  """SymbolTree class define of Rewrite according to forward function of a network."""
16
16
  import stat
17
- from typing import Optional, Union, Tuple, Any
17
+ from typing import Optional, Union, Tuple, Any, Dict, List
18
18
  import os
19
19
  import sys
20
20
  import ast
21
21
  import importlib.util
22
- import types
23
22
  import time
24
- import astunparse
25
23
 
26
24
  from mindspore.nn import Cell
27
25
  from mindspore import log as logger
28
- from mindspore.rewrite.ast_creator_register import ast_creator_registry
29
- from .node import Node, TreeNode
26
+ from .node.node import Node, TreeNode
30
27
  from .api.node_type import NodeType
31
- from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
28
+ from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
32
29
  from .api.scoped_value import ScopedValue, ValueType
33
30
  from .symbol_tree_dumper import SymbolTreeDumper
34
- from .topological_manager import TopoManager
31
+ from .node.node_topological_manager import TopoManager
35
32
  from .namer import TargetNamer, NodeNamer, ClassNamer
36
33
  from .common.observer import Observer
37
34
  from .common.observable import Observable
38
35
  from .common.event import Event
36
+ from .node.node_manager import NodeManager
39
37
 
38
+ if sys.version_info >= (3, 9):
39
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
40
+ else:
41
+ import astunparse
40
42
 
41
43
  class Position:
42
44
  """
@@ -80,6 +82,7 @@ class FieldFinder(AstFinder):
80
82
  Args:
81
83
  scope (ast.AST): An instance of ast node as search scope.
82
84
  """
85
+
83
86
  def __init__(self, scope: ast.AST):
84
87
  super().__init__(scope)
85
88
  self._result = False
@@ -133,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
133
136
  self.generic_visit(node)
134
137
 
135
138
 
136
- class SymbolTree(Observer, Observable):
139
+ class SymbolTree(Observer, Observable, NodeManager):
137
140
  """
138
141
  A symbol-tree usually corresponding to forward method of a network.
139
142
 
@@ -146,147 +149,138 @@ class SymbolTree(Observer, Observable):
146
149
  """
147
150
 
148
151
  def __init__(self, origin_network: Cell, module_ast: ast.Module):
149
- super().__init__()
152
+ Observer.__init__(self)
150
153
  Observable.__init__(self)
151
- origin_network_key = "handler"
154
+ self._node_namer = NodeNamer()
155
+ self._node_namer.add_name('obj')
156
+ NodeManager.__init__(self, self._node_namer)
157
+ NodeManager.reg_observer(self, observer=self)
152
158
  # init unique-namers
153
159
  self._target_namer = TargetNamer()
154
- self._node_name_namer = NodeNamer()
155
- # name or node would use as name of field, so name of origin network handler field should be added into \
156
- # _node_name_namer.
157
- self._node_name_namer.add_name(origin_network_key)
158
- self._topo_mgr = TopoManager(self)
159
- self._topo_mgr.reg_observer(self)
160
-
161
- self._nodes: {str, Node} = {}
162
- # parameters of forward method
163
- self._inputs: [Node] = []
160
+ # input arguments of function
164
161
  self._ori_cls_name = type(origin_network).__name__
165
162
  self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
163
+ NodeManager.set_manager_name(self, self._opt_cls_name)
166
164
  self._origin_network = origin_network
167
165
  self._module_ast: ast.Module = module_ast
166
+ self._import_asts: Optional[ast.Ast] = []
168
167
  self._class_ast: Optional[ast.ClassDef] = None
169
168
  self._root_ast: Optional[ast.FunctionDef] = None
170
169
  self._init_func_ast: Optional[ast.FunctionDef] = None
171
170
  self._deleted_field = {}
172
171
  self._deleted_node = []
173
- self._external_func_ast = []
172
+ self._external_ast = []
174
173
  self._father_class_ast = []
175
-
176
- # head node is always point to the first node(in source code order) of SymbolTree
177
- self._head = None
178
- # tail node is always point to the last node(in source code order) of SymbolTree
179
- self._tail = None
180
- self._return: Optional[Node] = None
181
-
182
174
  self._modified = False
183
- self._node_visitor = None
184
-
185
175
  self._tmp_file_limits = 20
186
176
  self._tmp_files = []
187
177
  self._saved_file_name = "./network_define.py"
188
178
  # used to insert "sys.path.append(xxx)"
189
179
  self._net_file_paths = []
180
+ self._tmp_import_strs = []
181
+ self._tmp_unmodified_strees: {type, str} = {}
182
+ self._tmp_replacers = []
183
+ # Record imported modules and names of each files
184
+ # The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
185
+ # Format: {file_path: {module: [name, ...], ...}, ...}
186
+ self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
190
187
 
191
188
  def __del__(self):
192
189
  for tmp_file in self._tmp_files:
193
190
  tmp_file.close()
194
191
 
195
192
  @staticmethod
196
- def _find_consumers_and_providers(nodes: [Node]):
197
- """
198
- Find consumers and providers for all nodes according to their targets and arguments.
199
- """
200
- consumers: {ScopedValue: [Node]} = {}
201
- providers: {ScopedValue: Node} = {}
202
- for node in nodes:
203
- for arg in node.get_args():
204
- if consumers.get(arg):
205
- consumers[arg].append(node)
206
- else:
207
- consumers[arg] = [node]
208
- for _, arg in node.get_kwargs():
209
- if consumers.get(arg):
210
- consumers[arg].append(node)
211
- else:
212
- consumers[arg] = [node]
213
- for target in node.get_targets():
214
- if providers.get(target) is not None:
215
- raise RuntimeError(f"Target({target}) of node duplicated")
216
- providers[target] = node
217
- return consumers, providers
218
-
219
- @staticmethod
220
- def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
221
- """Find all non-duplicated class name of SymbolTree recursively."""
222
- replacer = AstReplacer(stree.get_class_ast())
223
- replacers.append(replacer)
224
- for node in stree.nodes():
225
- if not isinstance(node, TreeNode):
193
+ def _remove_unused_import(module_ast):
194
+ """remove unused import in self._module_ast"""
195
+ str_checker = StrChecker(module_ast)
196
+ for i in range(len(module_ast.body) - 1, -1, -1):
197
+ body = module_ast.body[i]
198
+ if not isinstance(body, (ast.Import, ast.ImportFrom)):
226
199
  continue
227
- if node.symbol_tree.get_class_ast() is None:
200
+ if isinstance(body, ast.Import):
228
201
  continue
229
- sub_stree: SymbolTree = node.symbol_tree
230
- SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
231
- # all modified ast.ClassDef should export to code
232
- if sub_stree._modified:
233
- allow_class_name.append(sub_stree._class_ast.name)
202
+ if isinstance(body, ast.ImportFrom) and body.module == "cell":
203
+ module_ast.body.remove(body)
234
204
  continue
235
- # all un-modified ast.ClassDef only keep one instance
236
- seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
237
- if seen_cls_name is not None:
238
- replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
239
- else:
240
- seen_class[type(sub_stree.get_origin_network())] = sub_stree.get_class_ast().name
241
- allow_class_name.append(sub_stree.get_class_ast().name)
205
+ for alias in body.names:
206
+ name = alias.asname if alias.asname else alias.name
207
+ if not str_checker.check(name):
208
+ if len(body.names) == 1:
209
+ module_ast.body.remove(body)
210
+ i += 1
211
+ else:
212
+ body.names.remove(alias)
213
+
214
+ @staticmethod
215
+ def _remove_duplicated_import(module_ast):
216
+ """Remove duplicated import of 'net'."""
217
+ imports = set()
218
+ futures = set()
219
+ classes = set()
220
+
221
+ class TransImportNode(ast.NodeTransformer):
222
+ """Find all import nodes from input ast node."""
223
+
224
+ def visit_ClassDef(self, node: ast.ClassDef) -> Any:
225
+ class_str = astunparse.unparse(node)
226
+ if class_str not in classes:
227
+ classes.add(node.name)
228
+ return node
229
+ return
230
+
231
+ def visit_Try(self, node: ast.Try) -> Any:
232
+ if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
233
+ import_str = astunparse.unparse(node)
234
+ if import_str not in imports:
235
+ imports.add(import_str)
236
+ return node
237
+ return
238
+
239
+ def visit_Import(self, node: ast.Import) -> Any:
240
+ import_str = astunparse.unparse(node)
241
+ if import_str not in imports:
242
+ imports.add(import_str)
243
+ return node
244
+ return
245
+
246
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
247
+ """
248
+ Once the father class 'A' is defined in the current module, all the next imported class 'A' should
249
+ be removed. e.g.
250
+ def class A():
251
+ ...
252
+ from xxx import A, B
253
+ =>
254
+ def class A():
255
+ ...
256
+ from xxx import B
257
+ """
258
+ import_str = astunparse.unparse(node)
259
+
260
+ if import_str not in imports:
261
+ imports.add(import_str)
262
+ # remove "__future__" module
263
+ if node.module == '__future__':
264
+ futures.add(node.module)
265
+ return
266
+ # remove modules which have been defined in the code file
267
+ # it occurs when class A is a father class and other sub-classes import A
268
+ for alias in node.names[:]:
269
+ if alias.name in classes:
270
+ node.names.remove(alias)
271
+ # if the alias(es) in node.names are all removed, this import statement should be removed
272
+ if not node.names:
273
+ return
274
+ return node
275
+ return
276
+
277
+ get_node_handler = TransImportNode()
278
+ get_node_handler.generic_visit(module_ast)
242
279
 
243
280
  def finish_build(self):
244
281
  """Add Event.TopologicalChangeEvent event when build is finished."""
245
282
  self.add_event(Event.TopologicalChangeEvent)
246
283
 
247
- def create_assign_node(self, targets, func_name, args, kwargs):
248
- """
249
- Create a ast.Assign type node.
250
-
251
- Args:
252
- targets (list): _description_
253
- func_name (_type_): _description_
254
- args (_type_): _description_
255
- kwargs (_type_): _description_
256
-
257
- Returns:
258
- _type_: _description_
259
- """
260
- # create targets
261
- ast_targets = [ast_creator_registry.get("Name")(targets)]
262
- # create call
263
- ast_func = ast_creator_registry.get("Attribute")(func_name)
264
- ast_args = ast_creator_registry.get("Args")(args)
265
- ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
266
- ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
267
- # create assign
268
- ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
269
- return ast_node
270
-
271
- def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
272
- '''
273
- Instantiate an instance of node whose type is `CallFunction`.
274
-
275
- Args:
276
- node_name (str): Name of node.
277
- func_name (str): Name of function.
278
- ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
279
- targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
280
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
281
- args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
282
- kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
283
- class.
284
- '''
285
- logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
286
- node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
287
- node.set_belong_symbol_tree(self)
288
- return node
289
-
290
284
  def get_ori_cls_name(self) -> str:
291
285
  """
292
286
  Get class name of original network.
@@ -342,6 +336,7 @@ class SymbolTree(Observer, Observable):
342
336
  corresponding network class.
343
337
  """
344
338
  self._root_ast = ast_node
339
+ NodeManager.set_ast_functiondef(self, ast_node)
345
340
 
346
341
  def get_class_ast(self):
347
342
  """
@@ -380,18 +375,6 @@ class SymbolTree(Observer, Observable):
380
375
  """
381
376
  self._init_func_ast = ast_node
382
377
 
383
- def get_inputs(self):
384
- return self._inputs
385
-
386
- def get_head_node(self):
387
- """
388
- Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
389
-
390
- Returns:
391
- An instance of node.
392
- """
393
- return self._head
394
-
395
378
  def get_origin_network(self):
396
379
  """
397
380
  Getter of `_origin_network`.
@@ -405,46 +388,53 @@ class SymbolTree(Observer, Observable):
405
388
  """Get dict of nodes"""
406
389
  return self._nodes
407
390
 
408
- def get_father_class_ast(self):
409
- """Get _father_class_ast"""
410
- return self._father_class_ast
391
+ def get_node_namer(self):
392
+ """Get _node_namer"""
393
+ return self._node_namer
411
394
 
412
- def append_net_file_path(self, file_path):
413
- """Append a file_path into _net_file_paths"""
414
- if file_path not in self._net_file_paths:
415
- self._net_file_paths.append(file_path)
416
-
417
- def get_net_file_path(self):
418
- """Get _net_file_paths"""
419
- return self._net_file_paths
395
+ def is_modified(self):
396
+ """
397
+ Check whether symbol tree is modified.
420
398
 
421
- def nodes(self):
399
+ Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
400
+ the symbol tree is created.
422
401
  """
423
- Get generator of nodes of current `SymbolTree`.
402
+ return self._modified
424
403
 
425
- Returns:
426
- A generator for iterating Nodes of `SymbolTree`.
404
+ def set_modified_true(self):
427
405
  """
428
- # Put nodes in the list to avoid iteration stops caused by node topology being modified
429
- nodes = []
430
- node = self._head
431
- while node is not None:
432
- nodes.append(node)
433
- node = node.get_next()
434
- return iter(nodes)
406
+ Set self._modified true.
435
407
 
436
- def get_node(self, node_name: str) -> Optional[Node]:
408
+ Self._modified is set true when 'if' exists in the original network.
409
+ In this situation, different original network instance tends to be different.
410
+ Hence, the class name should be updated.
437
411
  """
438
- Get node of current symbol_tree by `node_name`.
412
+ self._modified = True
439
413
 
440
- Args:
441
- node_name (str): A str represents name of node as key of query.
414
+ def get_import_asts(self):
415
+ """Get _import_asts"""
416
+ return self._import_asts
442
417
 
443
- Returns:
444
- An instance of Node if found else None.
445
- """
418
+ def get_external_ast(self):
419
+ """Get _external_ast"""
420
+ return self._external_ast
421
+
422
+ def get_father_class_ast(self):
423
+ """Get _father_class_ast"""
424
+ return self._father_class_ast
425
+
426
+ def get_imported_modules(self, file_path: str):
427
+ """Get all modules and module_paths in file of `file_path` ."""
428
+ return self._imported_modules.get(file_path, {})
446
429
 
447
- return self._nodes.get(node_name)
430
+ def save_imported_modules(self, file_path: str, module: str, names: List[str]):
431
+ """Save module and names into _imported_modules."""
432
+ imported_modules = self.get_imported_modules(file_path)
433
+ if imported_modules.get(module):
434
+ imported_modules[module].extend(names)
435
+ else:
436
+ imported_modules[module] = names
437
+ self._imported_modules[file_path] = imported_modules
448
438
 
449
439
  def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
450
440
  """
@@ -535,9 +525,11 @@ class SymbolTree(Observer, Observable):
535
525
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
536
526
  return Position.create(node.get_belong_symbol_tree(), node, False)
537
527
 
538
- def insert_node(self, position: Optional[Position], node: Node, insert_to_ast: bool = True) -> Node:
528
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
529
+ insert_to_ast: bool = True):
539
530
  """
540
- Insert a node into SymbolTree.
531
+ Insert a node before or after base_node.
532
+
541
533
  Note:
542
534
  Name of node will be unique while inserting node into SymbolTree.
543
535
 
@@ -556,57 +548,73 @@ class SymbolTree(Observer, Observable):
556
548
  Topological relation is updated and inputs of corresponding node is updated.
557
549
 
558
550
  Args:
559
- position (Position): A Position indicates an insert position point.
560
- node (Node): An instance of node to be inserted in.
561
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
562
- True.
551
+ new_node (Node): Node to be inserted.
552
+ base_node (Node): New node will be inserted before or after base_node.
553
+ before_node (bool): Indicate whether new node is inserted before base_node.
554
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
555
+ NodeManager of symboltree's construct function.
556
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
563
557
 
564
558
  Returns:
565
559
  An instance of node which has been inserted into SymbolTree.
566
560
 
567
561
  Raises:
568
562
  ValueError: Node in the SymbolTree is inserted into SymbolTree again.
569
- RuntimeError: If 'position' is not in current SymbolTree.
570
563
  RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
571
564
  """
572
- if node in self.nodes():
573
- raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {node.get_name()}")
574
- if position is not None and hasattr(position.node, "container"):
575
- cellcontainer = getattr(position.node, "container")
576
- index = cellcontainer.node_list.index(position.node)
577
- index = index if position.before_node else index + 1
578
- cellcontainer.insert(index, node)
579
- return node
580
- # if position in current SymbolTree
581
- if position is not None and position.symbol_tree is not self:
582
- raise RuntimeError("Position is not in current SymbolTree:", position)
583
- if position is not None and position.node.get_node_type() == NodeType.Input:
565
+ if new_node.get_belong_symbol_tree():
566
+ raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
567
+
568
+ # Check if base_node in current SymbolTree
569
+ if base_node is not None:
570
+ stree = base_node.get_belong_symbol_tree()
571
+ if stree is not None and stree is not self:
572
+ raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
573
+ f"current: {self.get_ori_cls_name()}.")
574
+
575
+ # Check if node is inserted between Input node
576
+ if base_node is not None and base_node.get_node_type() == NodeType.Input:
584
577
  valid = True
585
- if position.before_node:
578
+ if before_node:
586
579
  valid = False
587
- if position.node.get_next() is not None and position.node.get_next().get_node_type() == NodeType.Input:
580
+ if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
588
581
  valid = False
589
582
  if not valid:
590
- raise RuntimeError("Can not insert a node before or between parameters:", position)
591
- # unique node name while insert node into symbol_tree
592
- node_name = self._node_name_namer.get_name(node)
593
- node.set_name(node_name)
583
+ raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
584
+
594
585
  # save target name, which is used to provide unique target
595
- if node.get_targets():
596
- for target in node.get_targets():
586
+ if new_node.get_targets():
587
+ for target in new_node.get_targets():
597
588
  self._target_namer.add_name(str(target))
598
- self._handle_custom_obj_in_normalized_args(node)
599
- self._insert_node(position, node)
600
- if isinstance(node, TreeNode):
601
- node.symbol_tree.reg_observer(self)
602
- if self._node_visitor:
603
- self._node_visitor.append_node(node)
604
- # update init-function-ast and construct-function-ast
605
- if insert_to_ast:
606
- self._insert_to_ast_while_insert_node(node, position)
607
- return node
608
589
 
609
- def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
590
+ self._handle_custom_obj_in_normalized_args(new_node)
591
+
592
+ # Insert node into NodeManager
593
+ if node_manager is None:
594
+ if base_node is None:
595
+ raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
596
+ node_manager = base_node.get_node_manager()
597
+
598
+ # set node's _belong_symbol_tree
599
+ new_node.set_belong_symbol_tree(self)
600
+
601
+ if node_manager is self:
602
+ NodeManager.insert_node(self, new_node, base_node, before_node)
603
+ if insert_to_ast:
604
+ # update init-function-ast and construct-function-ast
605
+ self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
606
+ else:
607
+ node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
608
+
609
+ # register code changed event observer, which is used to update _modified flag.
610
+ if new_node.get_node_type() == NodeType.Tree:
611
+ new_node.symbol_tree.reg_observer(self)
612
+ elif isinstance(new_node, NodeManager):
613
+ new_node.reg_observer(self)
614
+
615
+ return new_node
616
+
617
+ def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
610
618
  """
611
619
  Append a node to SymbolTree.
612
620
 
@@ -614,13 +622,17 @@ class SymbolTree(Observer, Observable):
614
622
  node (Node): An instance of node to be appended.
615
623
  append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
616
624
  True.
625
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
626
+ NodeManager of symboltree's construct function.
617
627
 
618
628
  Returns:
619
629
  An instance of node which has been appended to SymbolTree.
620
630
  """
621
- return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast)
631
+ if node_manager is None:
632
+ node_manager = self
633
+ return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
622
634
 
623
- def append_origin_field(self, node: Node) -> Node:
635
+ def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
624
636
  """
625
637
  Append an original field node to SymbolTree. An original field node represents a node created from existing
626
638
  statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
@@ -629,26 +641,16 @@ class SymbolTree(Observer, Observable):
629
641
 
630
642
  Args:
631
643
  node (Node): An instance of node to be appended.
644
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
645
+ NodeManager of symboltree's construct function.
632
646
 
633
647
  Returns:
634
648
  An instance of node which has been appended to SymbolTree.
635
649
  """
636
- if node.get_node_type() == NodeType.Output:
637
- self._return = node
638
- elif node.get_node_type() == NodeType.Input:
639
- self._inputs.append(node)
640
- elif node.get_node_type() == NodeType.Tree:
641
- # add father_class_ast into main tree, used when get_code
642
- for father_ast in node.symbol_tree.get_father_class_ast():
643
- if father_ast not in self._father_class_ast:
644
- self._father_class_ast.append(father_ast)
645
- # add subtree's net path into main tree
646
- for file_path in node.symbol_tree.get_net_file_path():
647
- if file_path not in self._net_file_paths:
648
- self.append_net_file_path(file_path)
649
- return self.append_node(node, False)
650
-
651
- def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None):
650
+ return self.append_node(node, node_manager, False)
651
+
652
+ def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
653
+ node_manager: NodeManager = None):
652
654
  """
653
655
  Append an input node to SymbolTree corresponding to parameter of forward method of network class.
654
656
  This method is called while building SymbolTree usually.
@@ -658,13 +660,18 @@ class SymbolTree(Observer, Observable):
658
660
  param_name (str): A str represents name of parameter of forward method of network class.
659
661
  default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
660
662
  means parameter has no default value.
663
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
664
+ NodeManager of symboltree's construct function.
661
665
 
662
666
  Returns:
663
667
  An instance of input node which has been appended to SymbolTree.
664
668
  """
665
669
  if param_name == "self":
666
670
  return
667
- for input_node in self._inputs:
671
+ # check param_name duplicated
672
+ if node_manager is None:
673
+ node_manager = self
674
+ for input_node in node_manager._inputs:
668
675
  targets = input_node.get_targets()
669
676
  if len(targets) != 1:
670
677
  raise RuntimeError("targets should have 1 elements")
@@ -677,9 +684,10 @@ class SymbolTree(Observer, Observable):
677
684
  if exist_param == param_name:
678
685
  raise RuntimeError("input duplicated:", param_name)
679
686
  input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
680
- self.append_origin_field(input_node)
687
+ self.append_origin_field(input_node, node_manager)
681
688
 
682
- def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]:
689
+ def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
690
+ node_manager: NodeManager = None) -> Optional[Node]:
683
691
  """
684
692
  Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
685
693
  a list or a dict.
@@ -688,6 +696,8 @@ class SymbolTree(Observer, Observable):
688
696
  Args:
689
697
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
690
698
  ast_node (ast.AST): A ast node represents ast node.
699
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
700
+ NodeManager of symboltree's construct function.
691
701
 
692
702
  Returns:
693
703
  An instance of python node if a new node has been appended to SymbolTree else None.
@@ -696,9 +706,9 @@ class SymbolTree(Observer, Observable):
696
706
  return None
697
707
  if isinstance(ast_node, (list, dict)) and not ast_node:
698
708
  return None
699
- return self.append_python_node(ast_scope, ast_node)
709
+ return self.append_python_node(ast_scope, ast_node, node_manager)
700
710
 
701
- def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
711
+ def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
702
712
  """
703
713
  Append a python node to SymbolTree.
704
714
  This method is called while building SymbolTree usually.
@@ -706,39 +716,50 @@ class SymbolTree(Observer, Observable):
706
716
  Args:
707
717
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
708
718
  ast_node (ast.AST): A ast node represents ast node.
719
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
720
+ NodeManager of symboltree's construct function.
709
721
 
710
722
  Returns:
711
723
  An instance of python node which has been appended to SymbolTree.
712
724
  """
713
725
  logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
714
- node_name = self._node_name_namer.get_name(type(ast_node).__name__)
726
+ node_name = type(ast_node).__name__
715
727
  node = Node.create_python_node(ast_node, node_name)
716
- self._insert_node(Position.create(self, self._tail, False), node)
728
+ if node_manager is None or node_manager is self:
729
+ NodeManager.append_python_node(self, node)
730
+ else:
731
+ node_manager.append_python_node(node)
717
732
  return node
718
733
 
719
- def set_output(self, return_value: str, index: int) -> Node:
734
+ def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
735
+ node_manager: NodeManager = None) -> Node:
720
736
  """
721
737
  Update return value of return of forward method of network class.
722
738
 
723
739
  Args:
724
740
  return_value (str): A str represents new return value.
725
- index (int): A int indicates which return value to be updated.
741
+ arg_index (int): A int indicates which value in return to be updated.
742
+ return_idx (int): A int indicates which return node to be updated. Default: 0.
743
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
744
+ symboltree's construct function.
726
745
 
727
746
  Returns:
728
747
  An instance of node represents return node after updated.
729
748
  """
730
- if self._return is None:
731
- raise RuntimeError("SymbolTree has no output")
732
- self.set_node_arg(self._return, index, return_value)
733
- return self._return
749
+ node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
750
+ if not node_returns:
751
+ raise RuntimeError("Current node_manager has no output")
752
+ if return_idx >= len(node_returns):
753
+ raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
754
+ node_return = node_returns[return_idx]
755
+ self.set_node_arg(node_return, arg_index, return_value)
756
+ return node_return
734
757
 
735
758
  def erase_node(self, node_or_name: Union[Node, str]) -> Node:
736
759
  """
737
760
  Erase a node from SymbolTree.
738
- Note:
739
- If node is depended on by other node, RuntimeError will raise.
740
761
 
741
- Topological relation is updated.
762
+ Topological relation will be updated.
742
763
 
743
764
  Args:
744
765
  node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
@@ -754,19 +775,21 @@ class SymbolTree(Observer, Observable):
754
775
  node = self._get_real_node(node_or_name)
755
776
  if node is None:
756
777
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
757
- if hasattr(node, "container"):
758
- cellcontainer = getattr(node, "container")
759
- cellcontainer.erase(node)
760
- return node
761
- ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
762
- if not ret:
763
- raise RuntimeError("node not in function ast tree.")
764
- self._topo_mgr.on_erase_node(node)
765
- for key, value in self._nodes.items():
766
- if id(value) == id(node):
767
- self._nodes.pop(key)
768
- value.isolate()
769
- break
778
+ # erase node in NodeManager
779
+ node_manager = node.get_node_manager()
780
+
781
+ logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
782
+ f"node_manager: {node_manager.get_manager_name()}, "
783
+ f"code: {astunparse.unparse(node.get_ast()).strip()}, "
784
+ f"node_name:{node.get_name()}")
785
+
786
+ if node_manager is self:
787
+ NodeManager.erase_node(self, node)
788
+ ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
789
+ if not ret:
790
+ raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
791
+ else:
792
+ node_manager.erase_node(node)
770
793
  self._deleted_node.append(node.get_name())
771
794
  return node
772
795
 
@@ -785,25 +808,16 @@ class SymbolTree(Observer, Observable):
785
808
  RuntimeError: If 'old_node' is isolated.
786
809
  RuntimeError: If 'old_node' is not belong to current SymbolTree.
787
810
  """
788
-
789
- if hasattr(old_node, "container"):
790
- self._replace_container_node(old_node, new_nodes)
791
- return new_nodes[0]
792
811
  real_old_node = self._get_real_node(old_node)
793
812
  if real_old_node is None:
794
813
  raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
795
- # get position
796
- next_node: Node = old_node.get_next()
797
- prev_node: Node = old_node.get_prev()
798
- if prev_node is None and next_node is None:
799
- raise RuntimeError("Try replacing a isolated node: ", old_node)
800
- if prev_node is None:
801
- position = self.before(next_node)
802
- else:
803
- position = self.after(prev_node)
814
+ # insert new_nodes into node_manager
815
+ node_manager = real_old_node.get_node_manager()
816
+ # insert new_nodes into NodeManager
817
+ base_node = old_node
804
818
  for node in new_nodes:
805
- self.insert_node(position, node, True)
806
- position = self.after(node)
819
+ self.insert_node(node, base_node, False, node_manager, True)
820
+ base_node = node
807
821
  self.erase_node(old_node)
808
822
  return new_nodes[-1]
809
823
 
@@ -868,6 +882,15 @@ class SymbolTree(Observer, Observable):
868
882
  """Get a unique name in the symboltree"""
869
883
  return self._target_namer.get_name(name)
870
884
 
885
+ def unique_func_name(self, name: str):
886
+ """Get a unique function name in the symboltree"""
887
+ if not hasattr(self._origin_network, name):
888
+ return name
889
+ suffix = 1
890
+ while hasattr(self._origin_network, f"{name}_{suffix}"):
891
+ suffix += 1
892
+ return f"{name}_{suffix}"
893
+
871
894
  def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
872
895
  """
873
896
  Set target of `node` .
@@ -895,21 +918,191 @@ class SymbolTree(Observer, Observable):
895
918
  node.set_targets(targets)
896
919
  self._topo_mgr.on_update_target(node, index, old_target, target)
897
920
 
898
- def print_node_tabulate(self):
899
- print(self._topo_mgr.dump())
921
+ def all_nodes(self):
922
+ """
923
+ Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
924
+
925
+ Returns:
926
+ A list of nodes.
927
+ """
928
+ nodes = []
929
+ node_managers = [self]
930
+ while node_managers:
931
+ node_manager = node_managers.pop()
932
+ nodes.extend(node_manager.nodes())
933
+ for node in node_manager.nodes():
934
+ if isinstance(node, NodeManager):
935
+ node_managers.append(node)
936
+ for tree_node in self.get_tree_nodes():
937
+ stree = tree_node.symbol_tree
938
+ nodes.extend(stree.all_nodes())
939
+ return nodes
940
+
941
+ def get_node_from_name(self, node_name: str):
942
+ """
943
+ Get node from all NodeManagers in current symbol tree by `node_name`.
944
+
945
+ Args:
946
+ node_name (str): A str represents name of node as key of query.
947
+
948
+ Returns:
949
+ An instance of Node if found else None.
950
+ """
951
+ node_managers = [self]
952
+ while node_managers:
953
+ node_manager = node_managers.pop()
954
+ node = node_manager.get_node(node_name)
955
+ if node:
956
+ return node
957
+ for node in node_manager.nodes():
958
+ if isinstance(node, NodeManager):
959
+ node_managers.append(node)
960
+ return None
961
+
962
+ def print_node_tabulate(self, all_nodes: bool = False):
963
+ """
964
+ Print nodes information and nodes' topological relations.
965
+
966
+ Args:
967
+ all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
968
+ nodes, CellContainer nodes and sub symbol trees.
969
+ """
970
+ try:
971
+ from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
972
+ except ImportError:
973
+ logger.warning("print_node_tabulate relies on the library `tabulate`, "
974
+ "which could not be found on this machine. Run `pip "
975
+ "install tabulate` to install the library.")
976
+ return ""
977
+ print(NodeManager.dump(self, self.get_manager_name()))
978
+ if all_nodes:
979
+ node_managers = [self]
980
+ while node_managers:
981
+ node_manager = node_managers.pop()
982
+ for node in node_manager.nodes():
983
+ if isinstance(node, NodeManager):
984
+ print(node.dump(node.get_manager_name()))
985
+ node_managers.append(node)
986
+ for tree_node in self.get_tree_nodes():
987
+ stree = tree_node.symbol_tree
988
+ stree.print_node_tabulate(all_nodes)
900
989
 
901
990
  def dump(self):
902
991
  """Dump graph."""
903
992
  dump_st = SymbolTreeDumper(self)
904
993
  dump_st.dump()
905
994
 
906
- def update_module_ast(self):
907
- for node in self._external_func_ast:
908
- self._module_ast.body.append(node)
909
- # Put father asts in front of first ClassDef
910
- index = [type(body) for body in self._module_ast.body].index(ast.ClassDef)
911
- for node in reversed(self._father_class_ast):
912
- self._module_ast.body.insert(index, node)
995
+ def check_body_exist(self, body, code_bodies):
996
+ """Check whether body already exist in code_bodies"""
997
+ # Check import ast node exist by saving import code string to self._tmp_import_strs
998
+ if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
999
+ import_str = astunparse.unparse(body)
1000
+ if import_str in self._tmp_import_strs:
1001
+ return True
1002
+ self._tmp_import_strs.append(import_str)
1003
+ return False
1004
+
1005
+ # Check ClassDef ast node exist by using AstClassFinder
1006
+ if isinstance(body, ast.ClassDef):
1007
+ if sys.version_info >= (3, 9):
1008
+ class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
1009
+ else:
1010
+ class_finder = AstClassFinder(ast.Module(body=code_bodies))
1011
+ results = class_finder.find_all(body.name)
1012
+ return bool(results)
1013
+
1014
+ # Check FunctionDef ast node exist by using AstFunctionFinder
1015
+ if isinstance(body, ast.FunctionDef):
1016
+ if sys.version_info >= (3, 9):
1017
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
1018
+ else:
1019
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
1020
+ results = function_finder.find_all(body.name)
1021
+ return bool(results)
1022
+
1023
+ return False
1024
+
1025
+ def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
1026
+ """
1027
+ For the unmodified symbol tree, only one definition code remains in the generated code.
1028
+ Everywhere else calling this symbol tree will use the class in this definition code.
1029
+ """
1030
+ # all modified ast.ClassDef will be exported to code
1031
+ if stree.is_modified():
1032
+ return False
1033
+ # all un-modified ast.ClassDef only keep one instance
1034
+ first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1035
+ if first_cls_name is None:
1036
+ class_ast = stree.get_class_ast()
1037
+ if class_ast:
1038
+ self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
1039
+ return False
1040
+ # Un-modified ast.ClassDef already exist in code_bodies,
1041
+ # replace class name to class name of first un-modified ast.ClassDef.
1042
+ if sys.version_info >= (3, 9):
1043
+ replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
1044
+ else:
1045
+ replacer = AstReplacer(ast.Module(body=code_bodies))
1046
+ replacer.replace_all(stree.get_class_ast().name, first_cls_name)
1047
+ self._tmp_replacers.append(replacer)
1048
+ return True
1049
+
1050
+ def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
1051
+ """
1052
+ Convert nodes in stree to code_bodies
1053
+
1054
+ 1. Add import asts into code_bodies
1055
+ 2. Add class, function and other type of asts into code_bodies
1056
+ 3. Add father class asts into code_bodies
1057
+ 4. Add external function asts into code_bodies
1058
+ 5. Add subtrees to code_bodies
1059
+ 5.1 Add subtrees in construct to code_bodies
1060
+ 5.2 Add subtrees in CellContainers to code_bodies
1061
+
1062
+ """
1063
+ # Add import asts into code_bodies
1064
+ for body in stree.get_import_asts():
1065
+ if not self.check_body_exist(body, code_bodies):
1066
+ code_bodies.insert(insert_pos, body)
1067
+ insert_pos += 1
1068
+
1069
+ # Add class, function and other type of asts into code_bodies
1070
+ if stree.get_module_ast():
1071
+ for body in stree.get_module_ast().body:
1072
+ if self.check_body_exist(body, code_bodies):
1073
+ continue
1074
+ if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
1075
+ code_bodies.insert(insert_pos, body)
1076
+ else:
1077
+ code_bodies.append(body)
1078
+
1079
+ # Add father class asts into code_bodies
1080
+ for body in reversed(stree.get_father_class_ast()):
1081
+ if self.check_body_exist(body, code_bodies):
1082
+ # remove exist ast in old position, then insert ast to upper position
1083
+ if sys.version_info >= (3, 9):
1084
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
1085
+ else:
1086
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
1087
+ code_bodies.remove(exist_ast)
1088
+ code_bodies.insert(insert_pos, body)
1089
+
1090
+ # Add external asts into code_bodies
1091
+ for body in stree.get_external_ast():
1092
+ if not self.check_body_exist(body, code_bodies):
1093
+ code_bodies.insert(insert_pos, body)
1094
+ insert_pos += 1
1095
+
1096
+ # Add subtrees to code_bodies
1097
+ for node in stree.get_tree_nodes():
1098
+ sub_stree = node.symbol_tree
1099
+ # Ignore TreeNode create by function in the class
1100
+ if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
1101
+ continue
1102
+ # For the unmodified class, update class name to name of first class
1103
+ if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
1104
+ continue
1105
+ self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
913
1106
 
914
1107
  def get_code(self) -> str:
915
1108
  """
@@ -918,34 +1111,22 @@ class SymbolTree(Observer, Observable):
918
1111
  Returns:
919
1112
  A str represents source code of modified network.
920
1113
  """
921
- self._remove_unused_import()
922
- if self._init_func_ast:
923
- self._remove_unused_field()
924
- self._remove_duplicated_import()
925
- self.update_module_ast()
1114
+ self._tmp_import_strs.clear()
1115
+ self._tmp_unmodified_strees.clear()
1116
+ self._tmp_replacers.clear()
1117
+ code_bodies = []
1118
+ self.convert_stree_to_code_bodies(self, code_bodies)
1119
+ if sys.version_info >= (3, 9):
1120
+ gencode_module = ast.Module(body=code_bodies, type_ignores=[])
1121
+ else:
1122
+ gencode_module = ast.Module(body=code_bodies)
1123
+ SymbolTree._remove_unused_import(gencode_module)
1124
+ SymbolTree._remove_duplicated_import(gencode_module)
926
1125
  ast.fix_missing_locations(self._module_ast)
927
- # Find all ast.ClassDef which can be export to code
928
- # Replace duplicated ast.ClassDef reference in main-ClassDef
929
- seen_class: {type, str} = {}
930
- allow_class_name = [self._class_ast.name]
931
- replacers = []
932
- SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
933
- # Add all non-ClassDef body to gencode_module
934
- # Add all ClassDef in allow_class_name to gencode_module
935
- # Use gencode_module to generate code
936
- bodies = []
937
- for body in self._module_ast.body:
938
- if not isinstance(body, ast.ClassDef):
939
- bodies.append(body)
940
- continue
941
- if body.name in allow_class_name:
942
- bodies.append(body)
943
- gencode_module = ast.Module(body=bodies)
944
- if_fixer = IfFixer()
945
- if_fixer.fix(gencode_module)
1126
+ IfFixer().fix(gencode_module)
946
1127
  code = astunparse.unparse(gencode_module)
947
- # Restore main-ClassDef
948
- for replacer in replacers:
1128
+ # Revert the class name to its original state
1129
+ for replacer in self._tmp_replacers:
949
1130
  replacer.undo_all()
950
1131
  return code
951
1132
 
@@ -979,251 +1160,71 @@ class SymbolTree(Observer, Observable):
979
1160
  f.write(source.encode('utf-8'))
980
1161
  f.flush()
981
1162
 
982
- def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
1163
+ def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
1164
+ node_manager: NodeManager):
983
1165
  """ insert_to_ast_while_insert_node. """
984
- node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
985
- node_ast = node.get_ast()
986
- if not isinstance(node_ast, ast.Assign):
987
- raise RuntimeError("Only support insert cell op now")
988
- if isinstance(node, TreeNode):
989
- setattr(self._origin_network, node.get_name(), node.get_instance())
990
- args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
991
- [ScopedValue(ValueType.NamingValue, "", "obj"),
992
- ScopedValue(ValueType.StringValue, "", node.get_name())])
993
- value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
994
- col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
995
-
996
- ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
997
- assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
998
- AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
999
-
1000
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1001
- None if position is None else position.node.get_ast(),
1002
- position.before_node)
1003
- sub_stree: SymbolTree = node.symbol_tree
1004
- from .symbol_tree_builder import SymbolTreeBuilder
1005
- SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
1166
+ if new_node.get_node_type() == NodeType.Input:
1167
+ # insert a new input
1168
+ self._inputs.append(new_node)
1169
+ ast_construct = self.get_ast_root()
1170
+ arg: str = new_node.get_targets()[0].value
1171
+ ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
1172
+ AstModifier.append_arg_to_function(ast_construct, ast_arg)
1006
1173
  else:
1007
- AstModifier.insert_assign_to_function(self._init_func_ast,
1008
- targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
1009
- expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
1010
- args=[ScopedValue(ValueType.NamingValue, "", "obj"),
1011
- ScopedValue(ValueType.StringValue, "", node.get_name())])
1012
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1013
- None if position is None else position.node.get_ast(),
1014
- position.before_node)
1015
- setattr(self._origin_network, node.get_name(), node.get_instance())
1016
-
1017
- def _remove_unused_import(self):
1018
- """remove unused import in self._module_ast"""
1019
- str_checker = StrChecker(self._module_ast)
1020
- for i in range(len(self._module_ast.body) - 1, -1, -1):
1021
- body = self._module_ast.body[i]
1022
- if not isinstance(body, (ast.Import, ast.ImportFrom)):
1023
- continue
1024
- if isinstance(body, ast.Import):
1025
- continue
1026
- if isinstance(body, ast.ImportFrom) and body.module == "cell":
1027
- self._module_ast.body.remove(body)
1028
- continue
1029
- for alias in body.names:
1030
- name = alias.asname if alias.asname else alias.name
1031
- if not str_checker.check(name):
1032
- if len(body.names) == 1:
1033
- self._module_ast.body.remove(body)
1034
- i += 1
1035
- else:
1036
- body.names.remove(alias)
1037
-
1038
- def _replace_container_node(self, old_node, new_nodes):
1039
- cellcontainer = getattr(old_node, "container")
1040
- index = cellcontainer.node_list.index(old_node)
1041
- for n in reversed(new_nodes):
1042
- cellcontainer.insert(index, n)
1043
- index = cellcontainer.node_list.index(old_node)
1044
- cellcontainer.erase(old_node)
1045
-
1046
- def _filter_out_to_delete_field(self, to_delete_field):
1047
- """filter out used field from `to_delete_field`"""
1048
- for func_def in self._class_ast.body:
1049
- if not isinstance(func_def, ast.FunctionDef):
1050
- continue
1051
- if func_def.name != "__init__":
1052
- to_delete_to_delete_keys = []
1053
- property_checker = CheckPropertyIsUsed(func_def)
1054
- for key, _ in self._deleted_field.items():
1055
- if property_checker.check("self", key):
1056
- to_delete_to_delete_keys.append(key)
1057
- property_checker = CheckPropertyIsUsed(func_def)
1058
- for key in to_delete_to_delete_keys:
1059
- self._deleted_field.pop(key)
1174
+ # insert a new assign statement
1175
+ ast_assign = new_node.get_ast()
1176
+ if ast_assign is None:
1177
+ func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
1178
+ new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
1179
+ ast_assign = new_node.update_ast_node()
1180
+ if not isinstance(ast_assign, ast.Assign):
1181
+ raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
1182
+ # Save instance into _origin_network.
1183
+ setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
1184
+ # Insert ast to __init__ function
1185
+ if isinstance(new_node, TreeNode):
1186
+ init_code = f"self.{new_node.get_name()} = " \
1187
+ f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
1060
1188
  else:
1061
- for body in func_def.body:
1062
- if not isinstance(body, ast.If):
1063
- continue
1064
- test = body.test
1065
- field_finder = FieldFinder(test)
1066
- to_delete_to_delete_keys = []
1067
- for key, _ in self._deleted_field.items():
1068
- if field_finder.check(key):
1069
- to_delete_to_delete_keys.append(key)
1070
- for key in to_delete_to_delete_keys:
1071
- self._deleted_field.pop(key)
1072
-
1073
- def _remove_unused_field(self):
1074
- """remove unused field in __init__ function"""
1075
- multi_targets = []
1076
- self._deleted_field = {}
1077
- for index, body in enumerate(self._init_func_ast.body):
1078
- if not isinstance(body, ast.Assign):
1079
- continue
1080
- targets = body.targets
1081
- for target in targets:
1082
- if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
1083
- and target.value.id == "self":
1084
- self._deleted_field[target.attr] = index
1085
- if len(targets) > 1:
1086
- multi_targets.append(index)
1087
- self._filter_out_to_delete_field(self._deleted_field)
1088
- for i in range(len(self._init_func_ast.body) - 1, -1, -1):
1089
- if i in self._deleted_field.values():
1090
- if i in multi_targets:
1091
- raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
1092
- AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
1093
- ast.fix_missing_locations(self._init_func_ast)
1094
-
1095
- def _remove_duplicated_import(self):
1096
- """Remove duplicated import of 'net'."""
1097
- imports = []
1098
- for body in self._module_ast.body:
1099
- if isinstance(body, (ast.ImportFrom, ast.Import)):
1100
- import_str = astunparse.unparse(body)
1101
- if import_str not in imports:
1102
- imports.append(import_str)
1103
- else:
1104
- self._module_ast.body.remove(body)
1189
+ init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
1190
+ init_ast = ast.parse(init_code).body[0]
1191
+ AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
1192
+ # Insert ast to construct_function/class_internal_function
1193
+ ast_base_node = base_node.get_ast() if base_node else None
1194
+ ast_functiondef = node_manager.get_ast_functiondef()
1195
+ if not ast_functiondef:
1196
+ raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
1197
+ "when inserting the ast.")
1198
+ AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
1105
1199
 
1106
1200
  def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1107
1201
  if isinstance(node_or_name, str):
1108
1202
  return self.get_node(node_or_name)
1109
1203
  return node_or_name
1110
1204
 
1111
- def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
1112
- """
1113
- Insert a node-tree into SymbolTree.
1114
- Note:
1115
- Inputs of intra sub-tree nodes need to be welly set.
1116
-
1117
- Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
1118
-
1119
- Args:
1120
- position (Position): A Position indicates an insert position point.
1121
- root (Node): An instance of node as root of node-tree to be inserted in.
1122
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
1123
- True.
1124
-
1125
- Returns:
1126
- An instance of node as root node of node-tree which has been inserted into SymbolTree.
1127
-
1128
- Raises:
1129
- RuntimeError: If 'position' is not in current SymbolTree.
1130
- """
1131
-
1132
- # if position not in current SymbolTree
1133
- if position.symbol_tree is not self:
1134
- raise RuntimeError("Position is not in current SymbolTree: ", position)
1135
-
1136
- queue: [Node] = [root]
1137
- todos: [] = []
1138
- inputs_list: [] = []
1139
- while queue:
1140
- cur_node = queue.pop(0)
1141
- if cur_node in todos:
1142
- continue
1143
- todos.append(cur_node)
1144
- node_inputs = cur_node.get_inputs()
1145
- inputs_list.append(node_inputs)
1146
- for node_input in node_inputs:
1147
- if node_input is not None:
1148
- queue.append(node_input)
1149
- todos.reverse()
1150
- inputs_list.reverse()
1151
- for index, todo in enumerate(todos):
1152
- self.insert_node(position, todo, insert_to_ast)
1153
- position = self.after(todo)
1154
- # relink input of node
1155
- original_inputs = inputs_list[index]
1156
- for arg_idx, original_input in enumerate(original_inputs):
1157
- if original_input is not None:
1158
- self.set_node_arg_by_node(todo, arg_idx, original_input)
1159
- return root
1160
-
1161
- def _add_node2nodes(self, node: Node):
1162
- """
1163
- Add `node` to `_nodes` dict.
1164
-
1165
- Args:
1166
- node (Node): A Node to be added into `_nodes`.
1167
-
1168
- Raises:
1169
- RuntimeError: If name of the node is duplicated.
1170
- """
1171
- node_name = node.get_name()
1172
- if self._nodes.get(node_name) is not None:
1173
- raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
1174
- node)
1175
- self._nodes[node_name] = node
1176
-
1177
- def _insert_node(self, position: Optional[Position], node: Node):
1178
- """
1179
- Insert a node into SymbolTree.
1180
- 1. Add `node` to `_nodes`.
1181
- 2. Insert `node` to node list(source code order).
1182
- 3. Update topological relation and update inputs of `node`.
1183
-
1184
- Args:
1185
- position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
1186
- of SymbolTree.
1187
- node (Node): A Node to be inserted into SymbolTree.
1188
-
1189
- Raises:
1190
- RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
1191
- unless inserting first node.
1192
- """
1193
- if position is None:
1194
- if self._nodes:
1195
- raise RuntimeError("self._nodes should be empty")
1196
- self._head = node
1197
- else:
1198
- if position.before_node:
1199
- position.node.insert_before(node)
1200
- else:
1201
- position.node.insert_after(node)
1202
- self._tail = node
1203
- self._add_node2nodes(node)
1204
- self._topo_mgr.on_insert_node(node)
1205
- node.set_belong_symbol_tree(self)
1206
-
1207
1205
  def _handle_custom_obj_in_normalized_args(self, node: Node):
1208
1206
  """
1209
- Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict.
1207
+ Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
1210
1208
 
1211
1209
  Args:
1212
1210
  node (Node): A Node whose arguments and keyword arguments to be handled.
1213
1211
  """
1214
- result: {str, ScopedValue} = {}
1215
- for arg, value in node.get_normalized_args().items():
1212
+ normalized_args: {str, ScopedValue} = {}
1213
+ for key, value in node.get_normalized_args().items():
1216
1214
  if not isinstance(value, ScopedValue):
1217
1215
  raise TypeError("value should be ScopedValue, got: ", type(value))
1218
1216
  if value.type == ValueType.CustomObjValue:
1219
- field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
1220
- setattr(self._origin_network, field, value.value)
1221
- init_targets = [ScopedValue.create_naming_value(field, "self")]
1222
- AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
1223
- result[arg] = init_targets[0]
1217
+ # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
1218
+ arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
1219
+ setattr(self._origin_network, arg_name, value.value)
1220
+ # Add new code to __init__(): self.arg_name = obj.arg_name
1221
+ new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
1222
+ self._init_func_ast.body.append(new_ast)
1223
+ # Modify node's normalized_args: CustomObjValue -> self.arg_name
1224
+ normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
1224
1225
  else:
1225
- result[arg] = value
1226
- node.set_normalized_args(result)
1226
+ normalized_args[key] = value
1227
+ node.set_normalized_args(normalized_args)
1227
1228
 
1228
1229
  def _get_cls_through_file(self):
1229
1230
  """
@@ -1235,12 +1236,14 @@ class SymbolTree(Observer, Observable):
1235
1236
  Returns:
1236
1237
  A class handle.
1237
1238
  """
1238
- self._update_container()
1239
1239
  file_path = os.getcwd()
1240
1240
  file_path = os.path.join(file_path, "rewritten_network")
1241
1241
  if not os.path.exists(file_path):
1242
- os.mkdir(file_path)
1243
- file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
1242
+ try:
1243
+ os.mkdir(file_path, mode=0o700)
1244
+ except FileExistsError:
1245
+ pass
1246
+ file_name = f"{self._opt_cls_name}_{id(self)}.py"
1244
1247
  network_file = os.path.join(file_path, file_name)
1245
1248
  with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1246
1249
  source = self.get_code()
@@ -1277,21 +1280,6 @@ class SymbolTree(Observer, Observable):
1277
1280
  self._modified = True
1278
1281
  self.changed(event)
1279
1282
 
1280
- def _update_container(self):
1281
- """Update instance of node in container."""
1282
- for node in self.nodes():
1283
- index = 0
1284
- if node.get_node_type() == NodeType.CellContainer:
1285
- for n in node.node_list:
1286
- if not n.valid:
1287
- continue
1288
- if n.get_node_type() == NodeType.Tree:
1289
- obj = n.symbol_tree.get_network()
1290
- node.get_instance()[index] = obj
1291
- else:
1292
- node.get_instance()[index] = n.get_instance()
1293
- index += 1
1294
-
1295
1283
  def _cal_difference_set(self, input, other):
1296
1284
  """Calculate different set of two sets."""
1297
1285
  set1 = set(input)
@@ -1313,43 +1301,3 @@ class SymbolTree(Observer, Observable):
1313
1301
  primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1314
1302
  for p in primitives:
1315
1303
  new_net._primitives[p] = self._origin_network._primitives[p]
1316
-
1317
- def _create_call_function(self, func, targets, args, kwargs):
1318
- """
1319
- Create a Node object and generate the execution code to insert into the source code.
1320
- The source code calls the 'func' function with 'args' and' kwargs' as parameters.
1321
-
1322
- Args:
1323
- func (FunctionType) - The function to be called.
1324
- targets (list [str]) - indicates the output name. As the output of the node in the source code.
1325
- args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
1326
- code. The default value is None, which means there is no parameter input in the cell.
1327
- kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
1328
- input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
1329
- code as the 'kwargs' in the statement expression. The default value is None, which means there is no
1330
- 'kwargs' input.
1331
-
1332
- Returns:
1333
- An instance of `Node`.
1334
- """
1335
- if not isinstance(func, types.FunctionType):
1336
- raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
1337
-
1338
- _package = func.__globals__['__package__']
1339
- func_name = ".".join([_package, func.__name__]) if _package else func.__name__
1340
-
1341
- ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
1342
- scope_targets = [ScopedValue.create_naming_value(targets[0])]
1343
- scope_func = ScopedValue.create_naming_value(func_name, "")
1344
- call_args = list()
1345
- for arg in args:
1346
- if isinstance(arg, Node):
1347
- call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
1348
- else:
1349
- call_args.append(ScopedValue.create_variable_value(arg))
1350
- call_kwargs = {}
1351
- for k, v in kwargs.items():
1352
- call_kwargs[k] = ScopedValue.create_variable_value(v)
1353
- node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
1354
- call_kwargs)
1355
- return node