mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (580) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +141 -88
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/cell_wrapper.py +84 -34
  268. mindspore/nn/wrap/grad_reducer.py +8 -5
  269. mindspore/nn/wrap/loss_scale.py +105 -42
  270. mindspore/numpy/array_creations.py +1 -2
  271. mindspore/numpy/array_ops.py +3 -2
  272. mindspore/numpy/utils_const.py +5 -5
  273. mindspore/offline_debug/convert_async.py +2 -2
  274. mindspore/ops/_grad_experimental/__init__.py +0 -5
  275. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  276. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  277. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  278. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  279. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  280. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  281. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  282. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  283. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  284. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  285. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  286. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  287. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  288. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  289. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  290. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  291. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  292. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  293. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  294. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  295. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  296. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  297. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  298. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  299. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  300. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  301. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  302. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  303. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  304. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  305. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  306. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  307. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  309. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  310. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  311. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  312. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  313. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  314. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  315. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  316. mindspore/ops/_primitive_cache.py +1 -1
  317. mindspore/ops/_tracefunc.py +45 -13
  318. mindspore/ops/_utils/utils.py +6 -1
  319. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  320. mindspore/ops/_vmap/vmap_base.py +3 -3
  321. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  322. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  323. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  324. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  325. mindspore/ops/arg_dtype_cast.py +54 -0
  326. mindspore/ops/composite/base.py +37 -10
  327. mindspore/ops/composite/math_ops.py +5 -4
  328. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  329. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  330. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  331. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  332. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  333. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  334. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  335. mindspore/ops/deprecated.py +304 -0
  336. mindspore/ops/function/__init__.py +4 -1
  337. mindspore/ops/function/array_func.py +174 -193
  338. mindspore/ops/function/clip_func.py +81 -13
  339. mindspore/ops/function/debug_func.py +1 -1
  340. mindspore/ops/function/grad/grad_func.py +18 -9
  341. mindspore/ops/function/image_func.py +10 -4
  342. mindspore/ops/function/linalg_func.py +5 -5
  343. mindspore/ops/function/math_func.py +575 -386
  344. mindspore/ops/function/nn_func.py +568 -260
  345. mindspore/ops/function/random_func.py +88 -57
  346. mindspore/ops/function/sparse_func.py +1 -1
  347. mindspore/ops/function/sparse_unary_func.py +14 -12
  348. mindspore/ops/function/vmap_func.py +6 -5
  349. mindspore/ops/functional.py +15 -10
  350. mindspore/ops/op_info_register.py +244 -25
  351. mindspore/ops/operations/__init__.py +28 -19
  352. mindspore/ops/operations/_grad_ops.py +72 -7
  353. mindspore/ops/operations/_inner_ops.py +350 -17
  354. mindspore/ops/operations/_quant_ops.py +4 -8
  355. mindspore/ops/operations/_sequence_ops.py +42 -0
  356. mindspore/ops/operations/array_ops.py +68 -282
  357. mindspore/ops/operations/comm_ops.py +107 -59
  358. mindspore/ops/operations/custom_ops.py +94 -70
  359. mindspore/ops/operations/debug_ops.py +8 -4
  360. mindspore/ops/operations/image_ops.py +18 -12
  361. mindspore/ops/operations/inner_ops.py +26 -3
  362. mindspore/ops/operations/math_ops.py +189 -141
  363. mindspore/ops/operations/nn_ops.py +794 -489
  364. mindspore/ops/operations/other_ops.py +0 -22
  365. mindspore/ops/operations/random_ops.py +53 -111
  366. mindspore/ops/operations/sparse_ops.py +3 -1
  367. mindspore/ops/primitive.py +24 -18
  368. mindspore/parallel/_auto_parallel_context.py +68 -8
  369. mindspore/parallel/_cost_model_context.py +2 -2
  370. mindspore/parallel/_offload_context.py +17 -3
  371. mindspore/parallel/_parallel_serialization.py +12 -5
  372. mindspore/parallel/_ps_context.py +12 -0
  373. mindspore/parallel/_tensor.py +18 -13
  374. mindspore/parallel/_transformer/layers.py +5 -3
  375. mindspore/parallel/_transformer/loss.py +1 -0
  376. mindspore/parallel/_transformer/moe.py +2 -2
  377. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  378. mindspore/parallel/_transformer/transformer.py +23 -3
  379. mindspore/parallel/_utils.py +11 -7
  380. mindspore/parallel/algo_parameter_config.py +85 -5
  381. mindspore/parallel/checkpoint_transform.py +19 -12
  382. mindspore/parallel/shard.py +21 -14
  383. mindspore/profiler/common/struct_type.py +3 -3
  384. mindspore/profiler/common/util.py +4 -2
  385. mindspore/profiler/envprofiling.py +1 -1
  386. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  387. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  388. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  389. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  390. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  391. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  392. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  393. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  394. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  395. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  396. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  397. mindspore/profiler/parser/flops_parser.py +15 -11
  398. mindspore/profiler/parser/framework_parser.py +38 -22
  399. mindspore/profiler/parser/hccl_parser.py +16 -12
  400. mindspore/profiler/parser/integrator.py +22 -11
  401. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  402. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  403. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  404. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  405. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  406. mindspore/profiler/parser/optime_parser.py +1 -1
  407. mindspore/profiler/parser/profiler_info.py +21 -2
  408. mindspore/profiler/parser/step_trace_parser.py +11 -14
  409. mindspore/profiler/profiling.py +179 -89
  410. mindspore/rewrite/api/node.py +102 -19
  411. mindspore/rewrite/api/node_type.py +5 -1
  412. mindspore/rewrite/api/pattern_engine.py +1 -1
  413. mindspore/rewrite/api/scoped_value.py +9 -17
  414. mindspore/rewrite/api/symbol_tree.py +131 -47
  415. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  416. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  417. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  418. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  419. mindspore/rewrite/common/rewrite_elog.py +5 -1
  420. mindspore/rewrite/namer.py +33 -24
  421. mindspore/rewrite/namespace.py +14 -5
  422. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  423. mindspore/rewrite/node/call_function.py +79 -0
  424. mindspore/rewrite/node/cell_container.py +135 -0
  425. mindspore/rewrite/node/control_flow.py +88 -0
  426. mindspore/rewrite/{node.py → node/node.py} +273 -234
  427. mindspore/rewrite/node/node_manager.py +254 -0
  428. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  429. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  430. mindspore/rewrite/parsers/assign_parser.py +216 -221
  431. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  432. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  433. mindspore/rewrite/parsers/constant_parser.py +9 -6
  434. mindspore/rewrite/parsers/container_parser.py +9 -7
  435. mindspore/rewrite/parsers/for_parser.py +36 -15
  436. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  437. mindspore/rewrite/parsers/if_parser.py +28 -24
  438. mindspore/rewrite/parsers/module_parser.py +196 -25
  439. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  440. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  441. mindspore/rewrite/parsers/return_parser.py +6 -6
  442. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  443. mindspore/rewrite/sparsify/utils.py +1 -1
  444. mindspore/rewrite/symbol_tree.py +523 -578
  445. mindspore/rewrite/symbol_tree_builder.py +9 -193
  446. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  447. mindspore/run_check/_check_version.py +6 -4
  448. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  449. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  450. mindspore/scipy/linalg.py +1 -1
  451. mindspore/scipy/optimize/minimize.py +7 -3
  452. mindspore/train/_utils.py +7 -3
  453. mindspore/train/amp.py +323 -123
  454. mindspore/train/anf_ir_pb2.py +14 -2
  455. mindspore/train/callback/_backup_and_restore.py +2 -12
  456. mindspore/train/callback/_callback.py +29 -4
  457. mindspore/train/callback/_checkpoint.py +23 -8
  458. mindspore/train/callback/_early_stop.py +2 -2
  459. mindspore/train/callback/_landscape.py +4 -4
  460. mindspore/train/callback/_loss_monitor.py +2 -2
  461. mindspore/train/callback/_on_request_exit.py +2 -2
  462. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  463. mindspore/train/callback/_summary_collector.py +15 -8
  464. mindspore/train/callback/_time_monitor.py +58 -5
  465. mindspore/train/data_sink.py +5 -11
  466. mindspore/train/dataset_helper.py +84 -57
  467. mindspore/train/loss_scale_manager.py +2 -2
  468. mindspore/train/metrics/__init__.py +3 -3
  469. mindspore/train/metrics/cosine_similarity.py +1 -1
  470. mindspore/train/metrics/hausdorff_distance.py +3 -2
  471. mindspore/train/metrics/mean_surface_distance.py +3 -2
  472. mindspore/train/metrics/metric.py +39 -19
  473. mindspore/train/metrics/roc.py +2 -2
  474. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  475. mindspore/train/mind_ir_pb2.py +85 -36
  476. mindspore/train/model.py +187 -47
  477. mindspore/train/serialization.py +487 -161
  478. mindspore/train/summary/_summary_adapter.py +1 -1
  479. mindspore/train/summary/_writer_pool.py +3 -2
  480. mindspore/train/summary/summary_record.py +37 -17
  481. mindspore/train/train_thor/convert_utils.py +3 -3
  482. mindspore/train/train_thor/dataset_helper.py +1 -1
  483. mindspore/version.py +1 -1
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  487. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  489. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  490. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  491. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  492. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  493. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  494. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  495. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  497. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  498. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  499. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  500. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  501. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  502. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  503. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  504. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  505. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  506. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  507. mindspore/_extends/graph_kernel/expander.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  509. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  510. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  511. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  512. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  513. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  514. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  515. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  516. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  517. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  518. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  519. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  520. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  521. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  522. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  523. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  524. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  525. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  526. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  527. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  528. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  529. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  530. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  531. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  532. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  534. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  535. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  536. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  537. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  538. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  539. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  543. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  544. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  545. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  546. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  547. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  548. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  549. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  550. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  551. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  552. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  553. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  554. mindspore/dataset/datapreprocess/__init__.py +0 -20
  555. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  556. mindspore/include/api/net.h +0 -142
  557. mindspore/nn/lr_scheduler.py +0 -262
  558. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  559. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  560. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  561. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  562. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  563. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  567. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  575. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  578. mindspore/rewrite/node_visitor.py +0 -44
  579. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  580. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -14,29 +14,31 @@
14
14
  # ============================================================================
15
15
  """SymbolTree class define of Rewrite according to forward function of a network."""
16
16
  import stat
17
- from typing import Optional, Union, Tuple, Any
17
+ from typing import Optional, Union, Tuple, Any, Dict, List
18
18
  import os
19
19
  import sys
20
20
  import ast
21
21
  import importlib.util
22
- import types
23
22
  import time
24
- import astunparse
25
23
 
26
24
  from mindspore.nn import Cell
27
25
  from mindspore import log as logger
28
- from mindspore.rewrite.ast_creator_register import ast_creator_registry
29
- from .node import Node, TreeNode
26
+ from .node.node import Node, TreeNode
30
27
  from .api.node_type import NodeType
31
- from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
28
+ from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
32
29
  from .api.scoped_value import ScopedValue, ValueType
33
30
  from .symbol_tree_dumper import SymbolTreeDumper
34
- from .topological_manager import TopoManager
31
+ from .node.node_topological_manager import TopoManager
35
32
  from .namer import TargetNamer, NodeNamer, ClassNamer
36
33
  from .common.observer import Observer
37
34
  from .common.observable import Observable
38
35
  from .common.event import Event
36
+ from .node.node_manager import NodeManager
39
37
 
38
+ if sys.version_info >= (3, 9):
39
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
40
+ else:
41
+ import astunparse
40
42
 
41
43
  class Position:
42
44
  """
@@ -80,6 +82,7 @@ class FieldFinder(AstFinder):
80
82
  Args:
81
83
  scope (ast.AST): An instance of ast node as search scope.
82
84
  """
85
+
83
86
  def __init__(self, scope: ast.AST):
84
87
  super().__init__(scope)
85
88
  self._result = False
@@ -133,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
133
136
  self.generic_visit(node)
134
137
 
135
138
 
136
- class SymbolTree(Observer, Observable):
139
+ class SymbolTree(Observer, Observable, NodeManager):
137
140
  """
138
141
  A symbol-tree usually corresponding to forward method of a network.
139
142
 
@@ -146,147 +149,135 @@ class SymbolTree(Observer, Observable):
146
149
  """
147
150
 
148
151
  def __init__(self, origin_network: Cell, module_ast: ast.Module):
149
- super().__init__()
152
+ Observer.__init__(self)
150
153
  Observable.__init__(self)
151
- origin_network_key = "handler"
154
+ self._node_namer = NodeNamer()
155
+ self._node_namer.add_name('obj')
156
+ NodeManager.__init__(self, self._node_namer)
157
+ NodeManager.reg_observer(self, observer=self)
152
158
  # init unique-namers
153
159
  self._target_namer = TargetNamer()
154
- self._node_name_namer = NodeNamer()
155
- # name or node would use as name of field, so name of origin network handler field should be added into \
156
- # _node_name_namer.
157
- self._node_name_namer.add_name(origin_network_key)
158
- self._topo_mgr = TopoManager(self)
159
- self._topo_mgr.reg_observer(self)
160
-
161
- self._nodes: {str, Node} = {}
162
- # parameters of forward method
163
- self._inputs: [Node] = []
160
+ # input arguments of function
164
161
  self._ori_cls_name = type(origin_network).__name__
165
162
  self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
163
+ NodeManager.set_manager_name(self, self._opt_cls_name)
166
164
  self._origin_network = origin_network
167
165
  self._module_ast: ast.Module = module_ast
166
+ self._import_asts: Optional[ast.Ast] = []
168
167
  self._class_ast: Optional[ast.ClassDef] = None
169
168
  self._root_ast: Optional[ast.FunctionDef] = None
170
169
  self._init_func_ast: Optional[ast.FunctionDef] = None
171
170
  self._deleted_field = {}
172
171
  self._deleted_node = []
173
- self._external_func_ast = []
172
+ self._external_ast = []
174
173
  self._father_class_ast = []
175
-
176
- # head node is always point to the first node(in source code order) of SymbolTree
177
- self._head = None
178
- # tail node is always point to the last node(in source code order) of SymbolTree
179
- self._tail = None
180
- self._return: Optional[Node] = None
181
-
182
174
  self._modified = False
183
- self._node_visitor = None
184
-
185
175
  self._tmp_file_limits = 20
186
176
  self._tmp_files = []
187
177
  self._saved_file_name = "./network_define.py"
188
178
  # used to insert "sys.path.append(xxx)"
189
179
  self._net_file_paths = []
180
+ self._tmp_import_strs = []
181
+ self._tmp_unmodified_strees: {type, str} = {}
182
+ self._tmp_replacers = []
183
+ # Record imported modules and names of each files
184
+ # The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
185
+ # Format: {file_path: {module: [name, ...], ...}, ...}
186
+ self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
190
187
 
191
188
  def __del__(self):
192
189
  for tmp_file in self._tmp_files:
193
190
  tmp_file.close()
194
191
 
195
192
  @staticmethod
196
- def _find_consumers_and_providers(nodes: [Node]):
197
- """
198
- Find consumers and providers for all nodes according to their targets and arguments.
199
- """
200
- consumers: {ScopedValue: [Node]} = {}
201
- providers: {ScopedValue: Node} = {}
202
- for node in nodes:
203
- for arg in node.get_args():
204
- if consumers.get(arg):
205
- consumers[arg].append(node)
206
- else:
207
- consumers[arg] = [node]
208
- for _, arg in node.get_kwargs():
209
- if consumers.get(arg):
210
- consumers[arg].append(node)
211
- else:
212
- consumers[arg] = [node]
213
- for target in node.get_targets():
214
- if providers.get(target) is not None:
215
- raise RuntimeError(f"Target({target}) of node duplicated")
216
- providers[target] = node
217
- return consumers, providers
218
-
219
- @staticmethod
220
- def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
221
- """Find all non-duplicated class name of SymbolTree recursively."""
222
- replacer = AstReplacer(stree.get_class_ast())
223
- replacers.append(replacer)
224
- for node in stree.nodes():
225
- if not isinstance(node, TreeNode):
193
+ def _remove_unused_import(module_ast):
194
+ """remove unused import in self._module_ast"""
195
+ str_checker = StrChecker(module_ast)
196
+ for i in range(len(module_ast.body) - 1, -1, -1):
197
+ body = module_ast.body[i]
198
+ if not isinstance(body, (ast.Import, ast.ImportFrom)):
226
199
  continue
227
- if node.symbol_tree.get_class_ast() is None:
200
+ if isinstance(body, ast.Import):
228
201
  continue
229
- sub_stree: SymbolTree = node.symbol_tree
230
- SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
231
- # all modified ast.ClassDef should export to code
232
- if sub_stree._modified:
233
- allow_class_name.append(sub_stree._class_ast.name)
202
+ if isinstance(body, ast.ImportFrom) and body.module == "cell":
203
+ module_ast.body.remove(body)
234
204
  continue
235
- # all un-modified ast.ClassDef only keep one instance
236
- seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
237
- if seen_cls_name is not None:
238
- replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
239
- else:
240
- seen_class[type(sub_stree.get_origin_network())] = sub_stree.get_class_ast().name
241
- allow_class_name.append(sub_stree.get_class_ast().name)
205
+ for alias in body.names:
206
+ name = alias.asname if alias.asname else alias.name
207
+ if not str_checker.check(name):
208
+ if len(body.names) == 1:
209
+ module_ast.body.remove(body)
210
+ i += 1
211
+ else:
212
+ body.names.remove(alias)
213
+
214
+ @staticmethod
215
+ def _remove_duplicated_import(module_ast):
216
+ """Remove duplicated import of 'net'."""
217
+ imports = set()
218
+ futures = set()
219
+ classes = set()
220
+
221
+ class TransImportNode(ast.NodeTransformer):
222
+ """Find all import nodes from input ast node."""
223
+
224
+ def visit_ClassDef(self, node: ast.ClassDef) -> Any:
225
+ class_str = astunparse.unparse(node)
226
+ if class_str not in classes:
227
+ classes.add(node.name)
228
+ return node
229
+
230
+ def visit_Try(self, node: ast.Try) -> Any:
231
+ if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
232
+ import_str = astunparse.unparse(node)
233
+ if import_str not in imports:
234
+ imports.add(import_str)
235
+ return node
236
+
237
+ def visit_Import(self, node: ast.Import) -> Any:
238
+ import_str = astunparse.unparse(node)
239
+ if import_str not in imports:
240
+ imports.add(import_str)
241
+ return node
242
+
243
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
244
+ """
245
+ Once the father class 'A' is defined in the current module, all the next imported class 'A' should
246
+ be removed. e.g.
247
+ def class A():
248
+ ...
249
+ from xxx import A, B
250
+ =>
251
+ def class A():
252
+ ...
253
+ from xxx import B
254
+ """
255
+ import_str = astunparse.unparse(node)
256
+
257
+ if import_str not in imports:
258
+ imports.add(import_str)
259
+ # remove "__future__" module
260
+ if node.module == '__future__':
261
+ futures.add(node.module)
262
+ return
263
+ # remove modules which have been defined in the code file
264
+ # it occurs when class A is a father class and other sub-classes import A
265
+ for alias in node.names[:]:
266
+ if alias.name in classes:
267
+ node.names.remove(alias)
268
+ # if the alias(es) in node.names are all removed, this import statement should be removed
269
+ if not node.names:
270
+ return
271
+ return node
272
+ return
273
+
274
+ get_node_handler = TransImportNode()
275
+ get_node_handler.generic_visit(module_ast)
242
276
 
243
277
  def finish_build(self):
244
278
  """Add Event.TopologicalChangeEvent event when build is finished."""
245
279
  self.add_event(Event.TopologicalChangeEvent)
246
280
 
247
- def create_assign_node(self, targets, func_name, args, kwargs):
248
- """
249
- Create a ast.Assign type node.
250
-
251
- Args:
252
- targets (list): _description_
253
- func_name (_type_): _description_
254
- args (_type_): _description_
255
- kwargs (_type_): _description_
256
-
257
- Returns:
258
- _type_: _description_
259
- """
260
- # create targets
261
- ast_targets = [ast_creator_registry.get("Name")(targets)]
262
- # create call
263
- ast_func = ast_creator_registry.get("Attribute")(func_name)
264
- ast_args = ast_creator_registry.get("Args")(args)
265
- ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
266
- ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
267
- # create assign
268
- ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
269
- return ast_node
270
-
271
- def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
272
- '''
273
- Instantiate an instance of node whose type is `CallFunction`.
274
-
275
- Args:
276
- node_name (str): Name of node.
277
- func_name (str): Name of function.
278
- ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
279
- targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
280
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
281
- args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
282
- kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
283
- class.
284
- '''
285
- logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
286
- node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
287
- node.set_belong_symbol_tree(self)
288
- return node
289
-
290
281
  def get_ori_cls_name(self) -> str:
291
282
  """
292
283
  Get class name of original network.
@@ -342,6 +333,7 @@ class SymbolTree(Observer, Observable):
342
333
  corresponding network class.
343
334
  """
344
335
  self._root_ast = ast_node
336
+ NodeManager.set_ast_functiondef(self, ast_node)
345
337
 
346
338
  def get_class_ast(self):
347
339
  """
@@ -380,18 +372,6 @@ class SymbolTree(Observer, Observable):
380
372
  """
381
373
  self._init_func_ast = ast_node
382
374
 
383
- def get_inputs(self):
384
- return self._inputs
385
-
386
- def get_head_node(self):
387
- """
388
- Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
389
-
390
- Returns:
391
- An instance of node.
392
- """
393
- return self._head
394
-
395
375
  def get_origin_network(self):
396
376
  """
397
377
  Getter of `_origin_network`.
@@ -405,46 +385,53 @@ class SymbolTree(Observer, Observable):
405
385
  """Get dict of nodes"""
406
386
  return self._nodes
407
387
 
408
- def get_father_class_ast(self):
409
- """Get _father_class_ast"""
410
- return self._father_class_ast
388
+ def get_node_namer(self):
389
+ """Get _node_namer"""
390
+ return self._node_namer
411
391
 
412
- def append_net_file_path(self, file_path):
413
- """Append a file_path into _net_file_paths"""
414
- if file_path not in self._net_file_paths:
415
- self._net_file_paths.append(file_path)
416
-
417
- def get_net_file_path(self):
418
- """Get _net_file_paths"""
419
- return self._net_file_paths
392
+ def is_modified(self):
393
+ """
394
+ Check whether symbol tree is modified.
420
395
 
421
- def nodes(self):
396
+ Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
397
+ the symbol tree is created.
422
398
  """
423
- Get generator of nodes of current `SymbolTree`.
399
+ return self._modified
424
400
 
425
- Returns:
426
- A generator for iterating Nodes of `SymbolTree`.
401
+ def set_modified_true(self):
427
402
  """
428
- # Put nodes in the list to avoid iteration stops caused by node topology being modified
429
- nodes = []
430
- node = self._head
431
- while node is not None:
432
- nodes.append(node)
433
- node = node.get_next()
434
- return iter(nodes)
403
+ Set self._modified true.
435
404
 
436
- def get_node(self, node_name: str) -> Optional[Node]:
405
+ Self._modified is set true when 'if' exists in the original network.
406
+ In this situation, different original network instance tends to be different.
407
+ Hence, the class name should be updated.
437
408
  """
438
- Get node of current symbol_tree by `node_name`.
409
+ self._modified = True
439
410
 
440
- Args:
441
- node_name (str): A str represents name of node as key of query.
411
+ def get_import_asts(self):
412
+ """Get _import_asts"""
413
+ return self._import_asts
442
414
 
443
- Returns:
444
- An instance of Node if found else None.
445
- """
415
+ def get_external_ast(self):
416
+ """Get _external_ast"""
417
+ return self._external_ast
418
+
419
+ def get_father_class_ast(self):
420
+ """Get _father_class_ast"""
421
+ return self._father_class_ast
422
+
423
+ def get_imported_modules(self, file_path: str):
424
+ """Get all modules and module_paths in file of `file_path` ."""
425
+ return self._imported_modules.get(file_path, {})
446
426
 
447
- return self._nodes.get(node_name)
427
+ def save_imported_modules(self, file_path: str, module: str, names: List[str]):
428
+ """Save module and names into _imported_modules."""
429
+ imported_modules = self.get_imported_modules(file_path)
430
+ if imported_modules.get(module):
431
+ imported_modules[module].extend(names)
432
+ else:
433
+ imported_modules[module] = names
434
+ self._imported_modules[file_path] = imported_modules
448
435
 
449
436
  def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
450
437
  """
@@ -535,9 +522,11 @@ class SymbolTree(Observer, Observable):
535
522
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
536
523
  return Position.create(node.get_belong_symbol_tree(), node, False)
537
524
 
538
- def insert_node(self, position: Optional[Position], node: Node, insert_to_ast: bool = True) -> Node:
525
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
526
+ insert_to_ast: bool = True):
539
527
  """
540
- Insert a node into SymbolTree.
528
+ Insert a node before or after base_node.
529
+
541
530
  Note:
542
531
  Name of node will be unique while inserting node into SymbolTree.
543
532
 
@@ -556,57 +545,73 @@ class SymbolTree(Observer, Observable):
556
545
  Topological relation is updated and inputs of corresponding node is updated.
557
546
 
558
547
  Args:
559
- position (Position): A Position indicates an insert position point.
560
- node (Node): An instance of node to be inserted in.
561
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
562
- True.
548
+ new_node (Node): Node to be inserted.
549
+ base_node (Node): New node will be inserted before or after base_node.
550
+ before_node (bool): Indicate whether new node is inserted before base_node.
551
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
552
+ NodeManager of symboltree's construct function.
553
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
563
554
 
564
555
  Returns:
565
556
  An instance of node which has been inserted into SymbolTree.
566
557
 
567
558
  Raises:
568
559
  ValueError: Node in the SymbolTree is inserted into SymbolTree again.
569
- RuntimeError: If 'position' is not in current SymbolTree.
570
560
  RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
571
561
  """
572
- if node in self.nodes():
573
- raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {node.get_name()}")
574
- if position is not None and hasattr(position.node, "container"):
575
- cellcontainer = getattr(position.node, "container")
576
- index = cellcontainer.node_list.index(position.node)
577
- index = index if position.before_node else index + 1
578
- cellcontainer.insert(index, node)
579
- return node
580
- # if position in current SymbolTree
581
- if position is not None and position.symbol_tree is not self:
582
- raise RuntimeError("Position is not in current SymbolTree:", position)
583
- if position is not None and position.node.get_node_type() == NodeType.Input:
562
+ if new_node.get_belong_symbol_tree():
563
+ raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
564
+
565
+ # Check if base_node in current SymbolTree
566
+ if base_node is not None:
567
+ stree = base_node.get_belong_symbol_tree()
568
+ if stree is not None and stree is not self:
569
+ raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
570
+ f"current: {self.get_ori_cls_name()}.")
571
+
572
+ # Check if node is inserted between Input node
573
+ if base_node is not None and base_node.get_node_type() == NodeType.Input:
584
574
  valid = True
585
- if position.before_node:
575
+ if before_node:
586
576
  valid = False
587
- if position.node.get_next() is not None and position.node.get_next().get_node_type() == NodeType.Input:
577
+ if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
588
578
  valid = False
589
579
  if not valid:
590
- raise RuntimeError("Can not insert a node before or between parameters:", position)
591
- # unique node name while insert node into symbol_tree
592
- node_name = self._node_name_namer.get_name(node)
593
- node.set_name(node_name)
580
+ raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
581
+
594
582
  # save target name, which is used to provide unique target
595
- if node.get_targets():
596
- for target in node.get_targets():
583
+ if new_node.get_targets():
584
+ for target in new_node.get_targets():
597
585
  self._target_namer.add_name(str(target))
598
- self._handle_custom_obj_in_normalized_args(node)
599
- self._insert_node(position, node)
600
- if isinstance(node, TreeNode):
601
- node.symbol_tree.reg_observer(self)
602
- if self._node_visitor:
603
- self._node_visitor.append_node(node)
604
- # update init-function-ast and construct-function-ast
605
- if insert_to_ast:
606
- self._insert_to_ast_while_insert_node(node, position)
607
- return node
608
586
 
609
- def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
587
+ self._handle_custom_obj_in_normalized_args(new_node)
588
+
589
+ # Insert node into NodeManager
590
+ if node_manager is None:
591
+ if base_node is None:
592
+ raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
593
+ node_manager = base_node.get_node_manager()
594
+
595
+ # set node's _belong_symbol_tree
596
+ new_node.set_belong_symbol_tree(self)
597
+
598
+ if node_manager is self:
599
+ NodeManager.insert_node(self, new_node, base_node, before_node)
600
+ if insert_to_ast:
601
+ # update init-function-ast and construct-function-ast
602
+ self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
603
+ else:
604
+ node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
605
+
606
+ # register code changed event observer, which is used to update _modified flag.
607
+ if new_node.get_node_type() == NodeType.Tree:
608
+ new_node.symbol_tree.reg_observer(self)
609
+ elif isinstance(new_node, NodeManager):
610
+ new_node.reg_observer(self)
611
+
612
+ return new_node
613
+
614
+ def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
610
615
  """
611
616
  Append a node to SymbolTree.
612
617
 
@@ -614,13 +619,17 @@ class SymbolTree(Observer, Observable):
614
619
  node (Node): An instance of node to be appended.
615
620
  append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
616
621
  True.
622
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
623
+ NodeManager of symboltree's construct function.
617
624
 
618
625
  Returns:
619
626
  An instance of node which has been appended to SymbolTree.
620
627
  """
621
- return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast)
628
+ if node_manager is None:
629
+ node_manager = self
630
+ return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
622
631
 
623
- def append_origin_field(self, node: Node) -> Node:
632
+ def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
624
633
  """
625
634
  Append an original field node to SymbolTree. An original field node represents a node created from existing
626
635
  statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
@@ -629,26 +638,16 @@ class SymbolTree(Observer, Observable):
629
638
 
630
639
  Args:
631
640
  node (Node): An instance of node to be appended.
641
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
642
+ NodeManager of symboltree's construct function.
632
643
 
633
644
  Returns:
634
645
  An instance of node which has been appended to SymbolTree.
635
646
  """
636
- if node.get_node_type() == NodeType.Output:
637
- self._return = node
638
- elif node.get_node_type() == NodeType.Input:
639
- self._inputs.append(node)
640
- elif node.get_node_type() == NodeType.Tree:
641
- # add father_class_ast into main tree, used when get_code
642
- for father_ast in node.symbol_tree.get_father_class_ast():
643
- if father_ast not in self._father_class_ast:
644
- self._father_class_ast.append(father_ast)
645
- # add subtree's net path into main tree
646
- for file_path in node.symbol_tree.get_net_file_path():
647
- if file_path not in self._net_file_paths:
648
- self.append_net_file_path(file_path)
649
- return self.append_node(node, False)
650
-
651
- def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None):
647
+ return self.append_node(node, node_manager, False)
648
+
649
+ def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
650
+ node_manager: NodeManager = None):
652
651
  """
653
652
  Append an input node to SymbolTree corresponding to parameter of forward method of network class.
654
653
  This method is called while building SymbolTree usually.
@@ -658,13 +657,18 @@ class SymbolTree(Observer, Observable):
658
657
  param_name (str): A str represents name of parameter of forward method of network class.
659
658
  default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
660
659
  means parameter has no default value.
660
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
661
+ NodeManager of symboltree's construct function.
661
662
 
662
663
  Returns:
663
664
  An instance of input node which has been appended to SymbolTree.
664
665
  """
665
666
  if param_name == "self":
666
667
  return
667
- for input_node in self._inputs:
668
+ # check param_name duplicated
669
+ if node_manager is None:
670
+ node_manager = self
671
+ for input_node in node_manager._inputs:
668
672
  targets = input_node.get_targets()
669
673
  if len(targets) != 1:
670
674
  raise RuntimeError("targets should have 1 elements")
@@ -677,9 +681,10 @@ class SymbolTree(Observer, Observable):
677
681
  if exist_param == param_name:
678
682
  raise RuntimeError("input duplicated:", param_name)
679
683
  input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
680
- self.append_origin_field(input_node)
684
+ self.append_origin_field(input_node, node_manager)
681
685
 
682
- def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]:
686
+ def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
687
+ node_manager: NodeManager = None) -> Optional[Node]:
683
688
  """
684
689
  Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
685
690
  a list or a dict.
@@ -688,6 +693,8 @@ class SymbolTree(Observer, Observable):
688
693
  Args:
689
694
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
690
695
  ast_node (ast.AST): A ast node represents ast node.
696
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
697
+ NodeManager of symboltree's construct function.
691
698
 
692
699
  Returns:
693
700
  An instance of python node if a new node has been appended to SymbolTree else None.
@@ -696,9 +703,9 @@ class SymbolTree(Observer, Observable):
696
703
  return None
697
704
  if isinstance(ast_node, (list, dict)) and not ast_node:
698
705
  return None
699
- return self.append_python_node(ast_scope, ast_node)
706
+ return self.append_python_node(ast_scope, ast_node, node_manager)
700
707
 
701
- def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
708
+ def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
702
709
  """
703
710
  Append a python node to SymbolTree.
704
711
  This method is called while building SymbolTree usually.
@@ -706,39 +713,50 @@ class SymbolTree(Observer, Observable):
706
713
  Args:
707
714
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
708
715
  ast_node (ast.AST): A ast node represents ast node.
716
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
717
+ NodeManager of symboltree's construct function.
709
718
 
710
719
  Returns:
711
720
  An instance of python node which has been appended to SymbolTree.
712
721
  """
713
722
  logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
714
- node_name = self._node_name_namer.get_name(type(ast_node).__name__)
723
+ node_name = type(ast_node).__name__
715
724
  node = Node.create_python_node(ast_node, node_name)
716
- self._insert_node(Position.create(self, self._tail, False), node)
725
+ if node_manager is None or node_manager is self:
726
+ NodeManager.append_python_node(self, node)
727
+ else:
728
+ node_manager.append_python_node(node)
717
729
  return node
718
730
 
719
- def set_output(self, return_value: str, index: int) -> Node:
731
+ def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
732
+ node_manager: NodeManager = None) -> Node:
720
733
  """
721
734
  Update return value of return of forward method of network class.
722
735
 
723
736
  Args:
724
737
  return_value (str): A str represents new return value.
725
- index (int): A int indicates which return value to be updated.
738
+ arg_index (int): A int indicates which value in return to be updated.
739
+ return_idx (int): A int indicates which return node to be updated. Default: 0.
740
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
741
+ symboltree's construct function.
726
742
 
727
743
  Returns:
728
744
  An instance of node represents return node after updated.
729
745
  """
730
- if self._return is None:
731
- raise RuntimeError("SymbolTree has no output")
732
- self.set_node_arg(self._return, index, return_value)
733
- return self._return
746
+ node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
747
+ if not node_returns:
748
+ raise RuntimeError("Current node_manager has no output")
749
+ if return_idx >= len(node_returns):
750
+ raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
751
+ node_return = node_returns[return_idx]
752
+ self.set_node_arg(node_return, arg_index, return_value)
753
+ return node_return
734
754
 
735
755
  def erase_node(self, node_or_name: Union[Node, str]) -> Node:
736
756
  """
737
757
  Erase a node from SymbolTree.
738
- Note:
739
- If node is depended on by other node, RuntimeError will raise.
740
758
 
741
- Topological relation is updated.
759
+ Topological relation will be updated.
742
760
 
743
761
  Args:
744
762
  node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
@@ -754,19 +772,21 @@ class SymbolTree(Observer, Observable):
754
772
  node = self._get_real_node(node_or_name)
755
773
  if node is None:
756
774
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
757
- if hasattr(node, "container"):
758
- cellcontainer = getattr(node, "container")
759
- cellcontainer.erase(node)
760
- return node
761
- ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
762
- if not ret:
763
- raise RuntimeError("node not in function ast tree.")
764
- self._topo_mgr.on_erase_node(node)
765
- for key, value in self._nodes.items():
766
- if id(value) == id(node):
767
- self._nodes.pop(key)
768
- value.isolate()
769
- break
775
+ # erase node in NodeManager
776
+ node_manager = node.get_node_manager()
777
+
778
+ logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
779
+ f"node_manager: {node_manager.get_manager_name()}, "
780
+ f"code: {astunparse.unparse(node.get_ast()).strip()}, "
781
+ f"node_name:{node.get_name()}")
782
+
783
+ if node_manager is self:
784
+ NodeManager.erase_node(self, node)
785
+ ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
786
+ if not ret:
787
+ raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
788
+ else:
789
+ node_manager.erase_node(node)
770
790
  self._deleted_node.append(node.get_name())
771
791
  return node
772
792
 
@@ -785,26 +805,17 @@ class SymbolTree(Observer, Observable):
785
805
  RuntimeError: If 'old_node' is isolated.
786
806
  RuntimeError: If 'old_node' is not belong to current SymbolTree.
787
807
  """
788
-
789
- if hasattr(old_node, "container"):
790
- self._replace_container_node(old_node, new_nodes)
791
- return new_nodes[0]
792
808
  real_old_node = self._get_real_node(old_node)
793
809
  if real_old_node is None:
794
810
  raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
795
- # get position
796
- next_node: Node = old_node.get_next()
797
- prev_node: Node = old_node.get_prev()
798
- if prev_node is None and next_node is None:
799
- raise RuntimeError("Try replacing a isolated node: ", old_node)
800
- if prev_node is None:
801
- position = self.before(next_node)
802
- else:
803
- position = self.after(prev_node)
811
+ # insert new_nodes into node_manager
812
+ node_manager = real_old_node.get_node_manager()
813
+ # insert new_nodes into NodeManager
814
+ base_node = old_node
804
815
  for node in new_nodes:
805
- self.insert_node(position, node, True)
806
- position = self.after(node)
807
- self.erase_node(old_node)
816
+ self.insert_node(node, base_node, False, node_manager, True)
817
+ base_node = node
818
+ _ = self.erase_node(old_node)
808
819
  return new_nodes[-1]
809
820
 
810
821
  def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
@@ -868,6 +879,15 @@ class SymbolTree(Observer, Observable):
868
879
  """Get a unique name in the symboltree"""
869
880
  return self._target_namer.get_name(name)
870
881
 
882
+ def unique_func_name(self, name: str):
883
+ """Get a unique function name in the symboltree"""
884
+ if not hasattr(self._origin_network, name):
885
+ return name
886
+ suffix = 1
887
+ while hasattr(self._origin_network, f"{name}_{suffix}"):
888
+ suffix += 1
889
+ return f"{name}_{suffix}"
890
+
871
891
  def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
872
892
  """
873
893
  Set target of `node` .
@@ -895,21 +915,191 @@ class SymbolTree(Observer, Observable):
895
915
  node.set_targets(targets)
896
916
  self._topo_mgr.on_update_target(node, index, old_target, target)
897
917
 
898
- def print_node_tabulate(self):
899
- print(self._topo_mgr.dump())
918
+ def all_nodes(self):
919
+ """
920
+ Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
921
+
922
+ Returns:
923
+ A list of nodes.
924
+ """
925
+ nodes = []
926
+ node_managers = [self]
927
+ while node_managers:
928
+ node_manager = node_managers.pop()
929
+ nodes.extend(node_manager.nodes())
930
+ for node in node_manager.nodes():
931
+ if isinstance(node, NodeManager):
932
+ node_managers.append(node)
933
+ for tree_node in self.get_tree_nodes():
934
+ stree = tree_node.symbol_tree
935
+ nodes.extend(stree.all_nodes())
936
+ return nodes
937
+
938
+ def get_node_from_name(self, node_name: str):
939
+ """
940
+ Get node from all NodeManagers in current symbol tree by `node_name`.
941
+
942
+ Args:
943
+ node_name (str): A str represents name of node as key of query.
944
+
945
+ Returns:
946
+ An instance of Node if found else None.
947
+ """
948
+ node_managers = [self]
949
+ while node_managers:
950
+ node_manager = node_managers.pop()
951
+ node = node_manager.get_node(node_name)
952
+ if node:
953
+ return node
954
+ for node in node_manager.nodes():
955
+ if isinstance(node, NodeManager):
956
+ node_managers.append(node)
957
+ return None
958
+
959
+ def print_node_tabulate(self, all_nodes: bool = False):
960
+ """
961
+ Print nodes information and nodes' topological relations.
962
+
963
+ Args:
964
+ all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
965
+ nodes, CellContainer nodes and sub symbol trees.
966
+ """
967
+ try:
968
+ from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
969
+ except ImportError:
970
+ logger.warning("print_node_tabulate relies on the library `tabulate`, "
971
+ "which could not be found on this machine. Run `pip "
972
+ "install tabulate` to install the library.")
973
+ return ""
974
+ print(NodeManager.dump(self, self.get_manager_name()))
975
+ if all_nodes:
976
+ node_managers = [self]
977
+ while node_managers:
978
+ node_manager = node_managers.pop()
979
+ for node in node_manager.nodes():
980
+ if isinstance(node, NodeManager):
981
+ print(node.dump(node.get_manager_name()))
982
+ node_managers.append(node)
983
+ for tree_node in self.get_tree_nodes():
984
+ stree = tree_node.symbol_tree
985
+ stree.print_node_tabulate(all_nodes)
900
986
 
901
987
  def dump(self):
902
988
  """Dump graph."""
903
989
  dump_st = SymbolTreeDumper(self)
904
990
  dump_st.dump()
905
991
 
906
- def update_module_ast(self):
907
- for node in self._external_func_ast:
908
- self._module_ast.body.append(node)
909
- # Put father asts in front of first ClassDef
910
- index = [type(body) for body in self._module_ast.body].index(ast.ClassDef)
911
- for node in reversed(self._father_class_ast):
912
- self._module_ast.body.insert(index, node)
992
+ def check_body_exist(self, body, code_bodies):
993
+ """Check whether body already exist in code_bodies"""
994
+ # Check import ast node exist by saving import code string to self._tmp_import_strs
995
+ if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
996
+ import_str = astunparse.unparse(body)
997
+ if import_str in self._tmp_import_strs:
998
+ return True
999
+ self._tmp_import_strs.append(import_str)
1000
+ return False
1001
+
1002
+ # Check ClassDef ast node exist by using AstClassFinder
1003
+ if isinstance(body, ast.ClassDef):
1004
+ if sys.version_info >= (3, 9):
1005
+ class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
1006
+ else:
1007
+ class_finder = AstClassFinder(ast.Module(body=code_bodies))
1008
+ results = class_finder.find_all(body.name)
1009
+ return bool(results)
1010
+
1011
+ # Check FunctionDef ast node exist by using AstFunctionFinder
1012
+ if isinstance(body, ast.FunctionDef):
1013
+ if sys.version_info >= (3, 9):
1014
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
1015
+ else:
1016
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
1017
+ results = function_finder.find_all(body.name)
1018
+ return bool(results)
1019
+
1020
+ return False
1021
+
1022
+ def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
1023
+ """
1024
+ For the unmodified symbol tree, only one definition code remains in the generated code.
1025
+ Everywhere else calling this symbol tree will use the class in this definition code.
1026
+ """
1027
+ # all modified ast.ClassDef will be exported to code
1028
+ if stree.is_modified():
1029
+ return False
1030
+ # all un-modified ast.ClassDef only keep one instance
1031
+ first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1032
+ if first_cls_name is None:
1033
+ class_ast = stree.get_class_ast()
1034
+ if class_ast:
1035
+ self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
1036
+ return False
1037
+ # Un-modified ast.ClassDef already exist in code_bodies,
1038
+ # replace class name to class name of first un-modified ast.ClassDef.
1039
+ if sys.version_info >= (3, 9):
1040
+ replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
1041
+ else:
1042
+ replacer = AstReplacer(ast.Module(body=code_bodies))
1043
+ replacer.replace_all(stree.get_class_ast().name, first_cls_name)
1044
+ self._tmp_replacers.append(replacer)
1045
+ return True
1046
+
1047
+ def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
1048
+ """
1049
+ Convert nodes in stree to code_bodies
1050
+
1051
+ 1. Add import asts into code_bodies
1052
+ 2. Add class, function and other type of asts into code_bodies
1053
+ 3. Add father class asts into code_bodies
1054
+ 4. Add external function asts into code_bodies
1055
+ 5. Add subtrees to code_bodies
1056
+ 5.1 Add subtrees in construct to code_bodies
1057
+ 5.2 Add subtrees in CellContainers to code_bodies
1058
+
1059
+ """
1060
+ # Add import asts into code_bodies
1061
+ for body in stree.get_import_asts():
1062
+ if not self.check_body_exist(body, code_bodies):
1063
+ code_bodies.insert(insert_pos, body)
1064
+ insert_pos += 1
1065
+
1066
+ # Add class, function and other type of asts into code_bodies
1067
+ if stree.get_module_ast():
1068
+ for body in stree.get_module_ast().body:
1069
+ if self.check_body_exist(body, code_bodies):
1070
+ continue
1071
+ if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
1072
+ code_bodies.insert(insert_pos, body)
1073
+ else:
1074
+ code_bodies.append(body)
1075
+
1076
+ # Add father class asts into code_bodies
1077
+ for body in reversed(stree.get_father_class_ast()):
1078
+ if self.check_body_exist(body, code_bodies):
1079
+ # remove exist ast in old position, then insert ast to upper position
1080
+ if sys.version_info >= (3, 9):
1081
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
1082
+ else:
1083
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
1084
+ code_bodies.remove(exist_ast)
1085
+ code_bodies.insert(insert_pos, body)
1086
+
1087
+ # Add external asts into code_bodies
1088
+ for body in stree.get_external_ast():
1089
+ if not self.check_body_exist(body, code_bodies):
1090
+ code_bodies.insert(insert_pos, body)
1091
+ insert_pos += 1
1092
+
1093
+ # Add subtrees to code_bodies
1094
+ for node in stree.get_tree_nodes():
1095
+ sub_stree = node.symbol_tree
1096
+ # Ignore TreeNode create by function in the class
1097
+ if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
1098
+ continue
1099
+ # For the unmodified class, update class name to name of first class
1100
+ if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
1101
+ continue
1102
+ self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
913
1103
 
914
1104
  def get_code(self) -> str:
915
1105
  """
@@ -918,34 +1108,22 @@ class SymbolTree(Observer, Observable):
918
1108
  Returns:
919
1109
  A str represents source code of modified network.
920
1110
  """
921
- self._remove_unused_import()
922
- if self._init_func_ast:
923
- self._remove_unused_field()
924
- self._remove_duplicated_import()
925
- self.update_module_ast()
1111
+ self._tmp_import_strs.clear()
1112
+ self._tmp_unmodified_strees.clear()
1113
+ self._tmp_replacers.clear()
1114
+ code_bodies = []
1115
+ self.convert_stree_to_code_bodies(self, code_bodies)
1116
+ if sys.version_info >= (3, 9):
1117
+ gencode_module = ast.Module(body=code_bodies, type_ignores=[])
1118
+ else:
1119
+ gencode_module = ast.Module(body=code_bodies)
1120
+ SymbolTree._remove_unused_import(gencode_module)
1121
+ SymbolTree._remove_duplicated_import(gencode_module)
926
1122
  ast.fix_missing_locations(self._module_ast)
927
- # Find all ast.ClassDef which can be export to code
928
- # Replace duplicated ast.ClassDef reference in main-ClassDef
929
- seen_class: {type, str} = {}
930
- allow_class_name = [self._class_ast.name]
931
- replacers = []
932
- SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
933
- # Add all non-ClassDef body to gencode_module
934
- # Add all ClassDef in allow_class_name to gencode_module
935
- # Use gencode_module to generate code
936
- bodies = []
937
- for body in self._module_ast.body:
938
- if not isinstance(body, ast.ClassDef):
939
- bodies.append(body)
940
- continue
941
- if body.name in allow_class_name:
942
- bodies.append(body)
943
- gencode_module = ast.Module(body=bodies)
944
- if_fixer = IfFixer()
945
- if_fixer.fix(gencode_module)
1123
+ IfFixer().fix(gencode_module)
946
1124
  code = astunparse.unparse(gencode_module)
947
- # Restore main-ClassDef
948
- for replacer in replacers:
1125
+ # Revert the class name to its original state
1126
+ for replacer in self._tmp_replacers:
949
1127
  replacer.undo_all()
950
1128
  return code
951
1129
 
@@ -979,251 +1157,71 @@ class SymbolTree(Observer, Observable):
979
1157
  f.write(source.encode('utf-8'))
980
1158
  f.flush()
981
1159
 
982
- def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
1160
+ def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
1161
+ node_manager: NodeManager):
983
1162
  """ insert_to_ast_while_insert_node. """
984
- node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
985
- node_ast = node.get_ast()
986
- if not isinstance(node_ast, ast.Assign):
987
- raise RuntimeError("Only support insert cell op now")
988
- if isinstance(node, TreeNode):
989
- setattr(self._origin_network, node.get_name(), node.get_instance())
990
- args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
991
- [ScopedValue(ValueType.NamingValue, "", "obj"),
992
- ScopedValue(ValueType.StringValue, "", node.get_name())])
993
- value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
994
- col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
995
-
996
- ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
997
- assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
998
- AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
999
-
1000
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1001
- None if position is None else position.node.get_ast(),
1002
- position.before_node)
1003
- sub_stree: SymbolTree = node.symbol_tree
1004
- from .symbol_tree_builder import SymbolTreeBuilder
1005
- SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
1163
+ if new_node.get_node_type() == NodeType.Input:
1164
+ # insert a new input
1165
+ self._inputs.append(new_node)
1166
+ ast_construct = self.get_ast_root()
1167
+ arg: str = new_node.get_targets()[0].value
1168
+ ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
1169
+ AstModifier.append_arg_to_function(ast_construct, ast_arg)
1006
1170
  else:
1007
- AstModifier.insert_assign_to_function(self._init_func_ast,
1008
- targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
1009
- expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
1010
- args=[ScopedValue(ValueType.NamingValue, "", "obj"),
1011
- ScopedValue(ValueType.StringValue, "", node.get_name())])
1012
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1013
- None if position is None else position.node.get_ast(),
1014
- position.before_node)
1015
- setattr(self._origin_network, node.get_name(), node.get_instance())
1016
-
1017
- def _remove_unused_import(self):
1018
- """remove unused import in self._module_ast"""
1019
- str_checker = StrChecker(self._module_ast)
1020
- for i in range(len(self._module_ast.body) - 1, -1, -1):
1021
- body = self._module_ast.body[i]
1022
- if not isinstance(body, (ast.Import, ast.ImportFrom)):
1023
- continue
1024
- if isinstance(body, ast.Import):
1025
- continue
1026
- if isinstance(body, ast.ImportFrom) and body.module == "cell":
1027
- self._module_ast.body.remove(body)
1028
- continue
1029
- for alias in body.names:
1030
- name = alias.asname if alias.asname else alias.name
1031
- if not str_checker.check(name):
1032
- if len(body.names) == 1:
1033
- self._module_ast.body.remove(body)
1034
- i += 1
1035
- else:
1036
- body.names.remove(alias)
1037
-
1038
- def _replace_container_node(self, old_node, new_nodes):
1039
- cellcontainer = getattr(old_node, "container")
1040
- index = cellcontainer.node_list.index(old_node)
1041
- for n in reversed(new_nodes):
1042
- cellcontainer.insert(index, n)
1043
- index = cellcontainer.node_list.index(old_node)
1044
- cellcontainer.erase(old_node)
1045
-
1046
- def _filter_out_to_delete_field(self, to_delete_field):
1047
- """filter out used field from `to_delete_field`"""
1048
- for func_def in self._class_ast.body:
1049
- if not isinstance(func_def, ast.FunctionDef):
1050
- continue
1051
- if func_def.name != "__init__":
1052
- to_delete_to_delete_keys = []
1053
- property_checker = CheckPropertyIsUsed(func_def)
1054
- for key, _ in self._deleted_field.items():
1055
- if property_checker.check("self", key):
1056
- to_delete_to_delete_keys.append(key)
1057
- property_checker = CheckPropertyIsUsed(func_def)
1058
- for key in to_delete_to_delete_keys:
1059
- self._deleted_field.pop(key)
1171
+ # insert a new assign statement
1172
+ ast_assign = new_node.get_ast()
1173
+ if ast_assign is None:
1174
+ func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
1175
+ new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
1176
+ ast_assign = new_node.update_ast_node()
1177
+ if not isinstance(ast_assign, ast.Assign):
1178
+ raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
1179
+ # Save instance into _origin_network.
1180
+ setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
1181
+ # Insert ast to __init__ function
1182
+ if isinstance(new_node, TreeNode):
1183
+ init_code = f"self.{new_node.get_name()} = " \
1184
+ f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
1060
1185
  else:
1061
- for body in func_def.body:
1062
- if not isinstance(body, ast.If):
1063
- continue
1064
- test = body.test
1065
- field_finder = FieldFinder(test)
1066
- to_delete_to_delete_keys = []
1067
- for key, _ in self._deleted_field.items():
1068
- if field_finder.check(key):
1069
- to_delete_to_delete_keys.append(key)
1070
- for key in to_delete_to_delete_keys:
1071
- self._deleted_field.pop(key)
1072
-
1073
- def _remove_unused_field(self):
1074
- """remove unused field in __init__ function"""
1075
- multi_targets = []
1076
- self._deleted_field = {}
1077
- for index, body in enumerate(self._init_func_ast.body):
1078
- if not isinstance(body, ast.Assign):
1079
- continue
1080
- targets = body.targets
1081
- for target in targets:
1082
- if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
1083
- and target.value.id == "self":
1084
- self._deleted_field[target.attr] = index
1085
- if len(targets) > 1:
1086
- multi_targets.append(index)
1087
- self._filter_out_to_delete_field(self._deleted_field)
1088
- for i in range(len(self._init_func_ast.body) - 1, -1, -1):
1089
- if i in self._deleted_field.values():
1090
- if i in multi_targets:
1091
- raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
1092
- AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
1093
- ast.fix_missing_locations(self._init_func_ast)
1094
-
1095
- def _remove_duplicated_import(self):
1096
- """Remove duplicated import of 'net'."""
1097
- imports = []
1098
- for body in self._module_ast.body:
1099
- if isinstance(body, (ast.ImportFrom, ast.Import)):
1100
- import_str = astunparse.unparse(body)
1101
- if import_str not in imports:
1102
- imports.append(import_str)
1103
- else:
1104
- self._module_ast.body.remove(body)
1186
+ init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
1187
+ init_ast = ast.parse(init_code).body[0]
1188
+ AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
1189
+ # Insert ast to construct_function/class_internal_function
1190
+ ast_base_node = base_node.get_ast() if base_node else None
1191
+ ast_functiondef = node_manager.get_ast_functiondef()
1192
+ if not ast_functiondef:
1193
+ raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
1194
+ "when inserting the ast.")
1195
+ AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
1105
1196
 
1106
1197
  def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1107
1198
  if isinstance(node_or_name, str):
1108
1199
  return self.get_node(node_or_name)
1109
1200
  return node_or_name
1110
1201
 
1111
- def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
1112
- """
1113
- Insert a node-tree into SymbolTree.
1114
- Note:
1115
- Inputs of intra sub-tree nodes need to be welly set.
1116
-
1117
- Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
1118
-
1119
- Args:
1120
- position (Position): A Position indicates an insert position point.
1121
- root (Node): An instance of node as root of node-tree to be inserted in.
1122
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
1123
- True.
1124
-
1125
- Returns:
1126
- An instance of node as root node of node-tree which has been inserted into SymbolTree.
1127
-
1128
- Raises:
1129
- RuntimeError: If 'position' is not in current SymbolTree.
1130
- """
1131
-
1132
- # if position not in current SymbolTree
1133
- if position.symbol_tree is not self:
1134
- raise RuntimeError("Position is not in current SymbolTree: ", position)
1135
-
1136
- queue: [Node] = [root]
1137
- todos: [] = []
1138
- inputs_list: [] = []
1139
- while queue:
1140
- cur_node = queue.pop(0)
1141
- if cur_node in todos:
1142
- continue
1143
- todos.append(cur_node)
1144
- node_inputs = cur_node.get_inputs()
1145
- inputs_list.append(node_inputs)
1146
- for node_input in node_inputs:
1147
- if node_input is not None:
1148
- queue.append(node_input)
1149
- todos.reverse()
1150
- inputs_list.reverse()
1151
- for index, todo in enumerate(todos):
1152
- self.insert_node(position, todo, insert_to_ast)
1153
- position = self.after(todo)
1154
- # relink input of node
1155
- original_inputs = inputs_list[index]
1156
- for arg_idx, original_input in enumerate(original_inputs):
1157
- if original_input is not None:
1158
- self.set_node_arg_by_node(todo, arg_idx, original_input)
1159
- return root
1160
-
1161
- def _add_node2nodes(self, node: Node):
1162
- """
1163
- Add `node` to `_nodes` dict.
1164
-
1165
- Args:
1166
- node (Node): A Node to be added into `_nodes`.
1167
-
1168
- Raises:
1169
- RuntimeError: If name of the node is duplicated.
1170
- """
1171
- node_name = node.get_name()
1172
- if self._nodes.get(node_name) is not None:
1173
- raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
1174
- node)
1175
- self._nodes[node_name] = node
1176
-
1177
- def _insert_node(self, position: Optional[Position], node: Node):
1178
- """
1179
- Insert a node into SymbolTree.
1180
- 1. Add `node` to `_nodes`.
1181
- 2. Insert `node` to node list(source code order).
1182
- 3. Update topological relation and update inputs of `node`.
1183
-
1184
- Args:
1185
- position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
1186
- of SymbolTree.
1187
- node (Node): A Node to be inserted into SymbolTree.
1188
-
1189
- Raises:
1190
- RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
1191
- unless inserting first node.
1192
- """
1193
- if position is None:
1194
- if self._nodes:
1195
- raise RuntimeError("self._nodes should be empty")
1196
- self._head = node
1197
- else:
1198
- if position.before_node:
1199
- position.node.insert_before(node)
1200
- else:
1201
- position.node.insert_after(node)
1202
- self._tail = node
1203
- self._add_node2nodes(node)
1204
- self._topo_mgr.on_insert_node(node)
1205
- node.set_belong_symbol_tree(self)
1206
-
1207
1202
  def _handle_custom_obj_in_normalized_args(self, node: Node):
1208
1203
  """
1209
- Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict.
1204
+ Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
1210
1205
 
1211
1206
  Args:
1212
1207
  node (Node): A Node whose arguments and keyword arguments to be handled.
1213
1208
  """
1214
- result: {str, ScopedValue} = {}
1215
- for arg, value in node.get_normalized_args().items():
1209
+ normalized_args: {str, ScopedValue} = {}
1210
+ for key, value in node.get_normalized_args().items():
1216
1211
  if not isinstance(value, ScopedValue):
1217
1212
  raise TypeError("value should be ScopedValue, got: ", type(value))
1218
1213
  if value.type == ValueType.CustomObjValue:
1219
- field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
1220
- setattr(self._origin_network, field, value.value)
1221
- init_targets = [ScopedValue.create_naming_value(field, "self")]
1222
- AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
1223
- result[arg] = init_targets[0]
1214
+ # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
1215
+ arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
1216
+ setattr(self._origin_network, arg_name, value.value)
1217
+ # Add new code to __init__(): self.arg_name = obj.arg_name
1218
+ new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
1219
+ self._init_func_ast.body.append(new_ast)
1220
+ # Modify node's normalized_args: CustomObjValue -> self.arg_name
1221
+ normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
1224
1222
  else:
1225
- result[arg] = value
1226
- node.set_normalized_args(result)
1223
+ normalized_args[key] = value
1224
+ node.set_normalized_args(normalized_args)
1227
1225
 
1228
1226
  def _get_cls_through_file(self):
1229
1227
  """
@@ -1235,12 +1233,14 @@ class SymbolTree(Observer, Observable):
1235
1233
  Returns:
1236
1234
  A class handle.
1237
1235
  """
1238
- self._update_container()
1239
1236
  file_path = os.getcwd()
1240
1237
  file_path = os.path.join(file_path, "rewritten_network")
1241
1238
  if not os.path.exists(file_path):
1242
- os.mkdir(file_path)
1243
- file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
1239
+ try:
1240
+ os.mkdir(file_path, mode=0o700)
1241
+ except FileExistsError:
1242
+ pass
1243
+ file_name = f"{self._opt_cls_name}_{id(self)}.py"
1244
1244
  network_file = os.path.join(file_path, file_name)
1245
1245
  with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1246
1246
  source = self.get_code()
@@ -1277,21 +1277,6 @@ class SymbolTree(Observer, Observable):
1277
1277
  self._modified = True
1278
1278
  self.changed(event)
1279
1279
 
1280
- def _update_container(self):
1281
- """Update instance of node in container."""
1282
- for node in self.nodes():
1283
- index = 0
1284
- if node.get_node_type() == NodeType.CellContainer:
1285
- for n in node.node_list:
1286
- if not n.valid:
1287
- continue
1288
- if n.get_node_type() == NodeType.Tree:
1289
- obj = n.symbol_tree.get_network()
1290
- node.get_instance()[index] = obj
1291
- else:
1292
- node.get_instance()[index] = n.get_instance()
1293
- index += 1
1294
-
1295
1280
  def _cal_difference_set(self, input, other):
1296
1281
  """Calculate different set of two sets."""
1297
1282
  set1 = set(input)
@@ -1313,43 +1298,3 @@ class SymbolTree(Observer, Observable):
1313
1298
  primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1314
1299
  for p in primitives:
1315
1300
  new_net._primitives[p] = self._origin_network._primitives[p]
1316
-
1317
- def _create_call_function(self, func, targets, args, kwargs):
1318
- """
1319
- Create a Node object and generate the execution code to insert into the source code.
1320
- The source code calls the 'func' function with 'args' and' kwargs' as parameters.
1321
-
1322
- Args:
1323
- func (FunctionType) - The function to be called.
1324
- targets (list [str]) - indicates the output name. As the output of the node in the source code.
1325
- args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
1326
- code. The default value is None, which means there is no parameter input in the cell.
1327
- kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
1328
- input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
1329
- code as the 'kwargs' in the statement expression. The default value is None, which means there is no
1330
- 'kwargs' input.
1331
-
1332
- Returns:
1333
- An instance of `Node`.
1334
- """
1335
- if not isinstance(func, types.FunctionType):
1336
- raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
1337
-
1338
- _package = func.__globals__['__package__']
1339
- func_name = ".".join([_package, func.__name__]) if _package else func.__name__
1340
-
1341
- ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
1342
- scope_targets = [ScopedValue.create_naming_value(targets[0])]
1343
- scope_func = ScopedValue.create_naming_value(func_name, "")
1344
- call_args = list()
1345
- for arg in args:
1346
- if isinstance(arg, Node):
1347
- call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
1348
- else:
1349
- call_args.append(ScopedValue.create_variable_value(arg))
1350
- call_kwargs = {}
1351
- for k, v in kwargs.items():
1352
- call_kwargs[k] = ScopedValue.create_variable_value(v)
1353
- node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
1354
- call_kwargs)
1355
- return node