mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (870)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Third_Party_Open_Source_Software_Notice +2 -2
  3. mindspore/__init__.py +5 -2
  4. mindspore/_akg/akg/build_module.py +5 -6
  5. mindspore/_akg/akg/composite/build_module.py +49 -16
  6. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  7. mindspore/_akg/akg/config/repository.json +195 -0
  8. mindspore/_akg/akg/global_configs.py +5 -1
  9. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  10. mindspore/_akg/akg/tvm/api.py +4 -3
  11. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  12. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  13. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  14. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  15. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  16. mindspore/_akg/akg/tvm/build_module.py +16 -1
  17. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  18. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  19. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  20. mindspore/_akg/akg/tvm/module.py +1 -2
  21. mindspore/_akg/akg/tvm/stmt.py +2 -2
  22. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  23. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  24. mindspore/_akg/akg/utils/op_dsl.py +17 -1
  25. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  26. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  27. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  28. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  29. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  30. mindspore/_check_jit_forbidden_api.py +5 -1
  31. mindspore/_checkparam.py +79 -62
  32. mindspore/_extends/graph_kernel/__init__.py +0 -1
  33. mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
  34. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  35. mindspore/_extends/graph_kernel/splitter.py +1 -9
  36. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
  37. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  38. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  39. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
  40. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
  41. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  42. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  43. mindspore/_extends/parse/__init__.py +19 -17
  44. mindspore/_extends/parse/namespace.py +7 -36
  45. mindspore/_extends/parse/parser.py +375 -189
  46. mindspore/_extends/parse/resources.py +36 -41
  47. mindspore/_extends/parse/standard_method.py +350 -245
  48. mindspore/_extends/parse/trope.py +2 -12
  49. mindspore/_extends/remote/kernel_build_server.py +24 -7
  50. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  51. mindspore/_install_custom.py +43 -0
  52. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  53. mindspore/amp.py +85 -19
  54. mindspore/bin/cache_admin +0 -0
  55. mindspore/bin/cache_server +0 -0
  56. mindspore/boost/base.py +2 -2
  57. mindspore/boost/boost.py +27 -32
  58. mindspore/boost/boost_cell_wrapper.py +37 -13
  59. mindspore/boost/grad_accumulation.py +1 -1
  60. mindspore/boost/grad_freeze.py +34 -6
  61. mindspore/boost/group_loss_scale_manager.py +15 -14
  62. mindspore/boost/less_batch_normalization.py +28 -3
  63. mindspore/common/__init__.py +15 -11
  64. mindspore/common/_auto_dynamic.py +68 -0
  65. mindspore/common/_jit_fallback_utils.py +111 -0
  66. mindspore/common/_register_for_adapter.py +17 -5
  67. mindspore/common/_register_for_tensor.py +2 -2
  68. mindspore/common/_stub_tensor.py +18 -15
  69. mindspore/common/_utils.py +31 -7
  70. mindspore/common/api.py +269 -101
  71. mindspore/common/auto_dynamic_shape.py +498 -0
  72. mindspore/common/dtype.py +61 -21
  73. mindspore/common/dump.py +9 -7
  74. mindspore/common/initializer.py +106 -76
  75. mindspore/common/jit_config.py +35 -14
  76. mindspore/common/lazy_inline.py +187 -0
  77. mindspore/common/mindir_util.py +101 -0
  78. mindspore/common/mutable.py +10 -13
  79. mindspore/common/parameter.py +246 -55
  80. mindspore/common/seed.py +13 -7
  81. mindspore/common/sparse_tensor.py +29 -33
  82. mindspore/common/tensor.py +907 -251
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +84 -4
  85. mindspore/communication/management.py +160 -88
  86. mindspore/config/op_info.config +99 -75
  87. mindspore/config/super_bar_config.json +36 -4
  88. mindspore/context.py +526 -219
  89. mindspore/dataset/__init__.py +9 -46
  90. mindspore/dataset/audio/__init__.py +4 -19
  91. mindspore/dataset/audio/transforms.py +545 -233
  92. mindspore/dataset/audio/utils.py +21 -18
  93. mindspore/dataset/callback/ds_callback.py +42 -13
  94. mindspore/dataset/core/config.py +158 -100
  95. mindspore/dataset/core/validator_helpers.py +1 -63
  96. mindspore/dataset/debug/debug_hook.py +45 -13
  97. mindspore/dataset/debug/pre_defined_hook.py +5 -5
  98. mindspore/dataset/engine/__init__.py +0 -5
  99. mindspore/dataset/engine/cache_client.py +38 -15
  100. mindspore/dataset/engine/datasets.py +615 -278
  101. mindspore/dataset/engine/datasets_audio.py +154 -283
  102. mindspore/dataset/engine/datasets_standard_format.py +104 -116
  103. mindspore/dataset/engine/datasets_text.py +443 -326
  104. mindspore/dataset/engine/datasets_user_defined.py +251 -164
  105. mindspore/dataset/engine/datasets_vision.py +839 -1443
  106. mindspore/dataset/engine/iterators.py +11 -4
  107. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
  108. mindspore/dataset/engine/obs/util.py +3 -0
  109. mindspore/dataset/engine/offload.py +6 -6
  110. mindspore/dataset/engine/queue.py +15 -14
  111. mindspore/dataset/engine/samplers.py +39 -23
  112. mindspore/dataset/engine/serializer_deserializer.py +22 -6
  113. mindspore/dataset/engine/validators.py +21 -331
  114. mindspore/dataset/text/__init__.py +5 -33
  115. mindspore/dataset/text/transforms.py +334 -165
  116. mindspore/dataset/text/utils.py +215 -145
  117. mindspore/dataset/transforms/__init__.py +1 -1
  118. mindspore/dataset/transforms/c_transforms.py +3 -2
  119. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  120. mindspore/dataset/transforms/transforms.py +174 -71
  121. mindspore/dataset/utils/browse_dataset.py +25 -17
  122. mindspore/dataset/utils/line_reader.py +24 -21
  123. mindspore/dataset/vision/__init__.py +5 -26
  124. mindspore/dataset/vision/c_transforms.py +177 -165
  125. mindspore/dataset/vision/py_transforms.py +114 -119
  126. mindspore/dataset/vision/py_transforms_util.py +54 -51
  127. mindspore/dataset/vision/transforms.py +1127 -381
  128. mindspore/dataset/vision/utils.py +54 -38
  129. mindspore/dataset/vision/validators.py +12 -2
  130. mindspore/experimental/map_parameter.py +38 -4
  131. mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
  132. mindspore/experimental/optim/adam.py +192 -0
  133. mindspore/experimental/optim/adamw.py +181 -0
  134. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  135. mindspore/experimental/optim/optimizer.py +252 -0
  136. mindspore/experimental/optim/sgd.py +147 -0
  137. mindspore/gen_ops.py +273 -0
  138. mindspore/include/OWNERS +1 -2
  139. mindspore/include/api/context.h +21 -1
  140. mindspore/include/api/data_type.h +2 -1
  141. mindspore/include/api/graph.h +0 -15
  142. mindspore/include/api/kernel.h +2 -0
  143. mindspore/include/api/kernel_api.h +37 -12
  144. mindspore/include/api/model.h +29 -42
  145. mindspore/include/api/model_group.h +14 -3
  146. mindspore/include/api/model_parallel_runner.h +18 -2
  147. mindspore/include/api/serialization.h +26 -0
  148. mindspore/include/api/status.h +1 -0
  149. mindspore/include/api/types.h +38 -4
  150. mindspore/include/c_api/ms/abstract.h +67 -0
  151. mindspore/include/c_api/ms/attribute.h +197 -0
  152. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  153. mindspore/include/c_api/ms/base/macros.h +32 -0
  154. mindspore/include/c_api/ms/base/status.h +33 -0
  155. mindspore/include/c_api/ms/base/types.h +282 -0
  156. mindspore/include/c_api/ms/context.h +102 -0
  157. mindspore/include/c_api/ms/graph.h +160 -0
  158. mindspore/include/c_api/ms/node.h +606 -0
  159. mindspore/include/c_api/ms/tensor.h +161 -0
  160. mindspore/include/c_api/ms/value.h +84 -0
  161. mindspore/include/c_api/status_c.h +3 -0
  162. mindspore/include/dataset/constants.h +6 -12
  163. mindspore/include/dataset/execute.h +23 -13
  164. mindspore/include/dataset/text.h +26 -26
  165. mindspore/include/dataset/transforms.h +25 -31
  166. mindspore/include/dataset/vision.h +60 -60
  167. mindspore/include/dataset/vision_ascend.h +5 -6
  168. mindspore/include/dataset/vision_lite.h +17 -17
  169. mindspore/include/mindapi/base/format.h +0 -1
  170. mindspore/include/mindapi/base/type_id.h +2 -1
  171. mindspore/include/mindapi/base/types.h +5 -1
  172. mindspore/lib/libdnnl.so.2 +0 -0
  173. mindspore/lib/libjemalloc.so.2 +0 -0
  174. mindspore/lib/libmindspore.so +0 -0
  175. mindspore/lib/libmindspore_backend.so +0 -0
  176. mindspore/lib/libmindspore_common.so +0 -0
  177. mindspore/lib/libmindspore_core.so +0 -0
  178. mindspore/lib/libmindspore_glog.so.0 +0 -0
  179. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  180. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  181. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  182. mindspore/lib/libmindspore_shared_lib.so +0 -0
  183. mindspore/lib/libmpi_adapter.so +0 -0
  184. mindspore/lib/libnnacl.so +0 -0
  185. mindspore/lib/libopencv_core.so.4.5 +0 -0
  186. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  187. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  188. mindspore/lib/libps_cache.so +0 -0
  189. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  190. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  191. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  192. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  193. mindspore/lib/plugin/ascend/libakg.so +0 -0
  194. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  195. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  196. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  197. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  198. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  199. mindspore/lib/plugin/cpu/libakg.so +0 -0
  200. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  201. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  202. mindspore/log.py +9 -6
  203. mindspore/mindrecord/filereader.py +33 -4
  204. mindspore/mindrecord/filewriter.py +70 -35
  205. mindspore/mindrecord/mindpage.py +40 -34
  206. mindspore/mindrecord/shardreader.py +1 -1
  207. mindspore/mindrecord/shardsegment.py +1 -1
  208. mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
  209. mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
  210. mindspore/mindrecord/tools/csv_to_mr.py +29 -13
  211. mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
  212. mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
  213. mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
  214. mindspore/nn/cell.py +463 -169
  215. mindspore/nn/dynamic_lr.py +47 -43
  216. mindspore/nn/layer/activation.py +225 -82
  217. mindspore/nn/layer/basic.py +121 -79
  218. mindspore/nn/layer/channel_shuffle.py +21 -21
  219. mindspore/nn/layer/combined.py +33 -26
  220. mindspore/nn/layer/container.py +277 -22
  221. mindspore/nn/layer/conv.py +441 -304
  222. mindspore/nn/layer/dense.py +19 -13
  223. mindspore/nn/layer/embedding.py +62 -49
  224. mindspore/nn/layer/flash_attention.py +264 -0
  225. mindspore/nn/layer/image.py +50 -39
  226. mindspore/nn/layer/math.py +62 -51
  227. mindspore/nn/layer/normalization.py +219 -167
  228. mindspore/nn/layer/padding.py +58 -70
  229. mindspore/nn/layer/pooling.py +334 -287
  230. mindspore/nn/layer/rnn_cells.py +53 -38
  231. mindspore/nn/layer/rnns.py +59 -56
  232. mindspore/nn/layer/thor_layer.py +52 -44
  233. mindspore/nn/layer/timedistributed.py +6 -4
  234. mindspore/nn/layer/transformer.py +284 -164
  235. mindspore/nn/learning_rate_schedule.py +34 -25
  236. mindspore/nn/loss/__init__.py +3 -2
  237. mindspore/nn/loss/loss.py +554 -311
  238. mindspore/nn/optim/ada_grad.py +12 -9
  239. mindspore/nn/optim/adadelta.py +14 -11
  240. mindspore/nn/optim/adafactor.py +19 -16
  241. mindspore/nn/optim/adam.py +62 -47
  242. mindspore/nn/optim/adamax.py +13 -10
  243. mindspore/nn/optim/adasum.py +12 -8
  244. mindspore/nn/optim/asgd.py +10 -9
  245. mindspore/nn/optim/ftrl.py +20 -17
  246. mindspore/nn/optim/lamb.py +16 -12
  247. mindspore/nn/optim/lars.py +8 -6
  248. mindspore/nn/optim/lazyadam.py +25 -20
  249. mindspore/nn/optim/momentum.py +10 -7
  250. mindspore/nn/optim/optimizer.py +61 -9
  251. mindspore/nn/optim/proximal_ada_grad.py +14 -13
  252. mindspore/nn/optim/rmsprop.py +17 -13
  253. mindspore/nn/optim/rprop.py +30 -17
  254. mindspore/nn/optim/sgd.py +40 -23
  255. mindspore/nn/optim/thor.py +24 -26
  256. mindspore/nn/probability/bijector/bijector.py +11 -11
  257. mindspore/nn/probability/bijector/exp.py +1 -1
  258. mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
  259. mindspore/nn/probability/bijector/invert.py +1 -1
  260. mindspore/nn/probability/bijector/power_transform.py +29 -29
  261. mindspore/nn/probability/bijector/scalar_affine.py +3 -3
  262. mindspore/nn/probability/bijector/softplus.py +5 -5
  263. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
  264. mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
  265. mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
  266. mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
  267. mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
  268. mindspore/nn/probability/distribution/_utils/utils.py +1 -1
  269. mindspore/nn/probability/distribution/bernoulli.py +9 -9
  270. mindspore/nn/probability/distribution/beta.py +8 -8
  271. mindspore/nn/probability/distribution/categorical.py +23 -15
  272. mindspore/nn/probability/distribution/cauchy.py +5 -6
  273. mindspore/nn/probability/distribution/distribution.py +3 -3
  274. mindspore/nn/probability/distribution/exponential.py +4 -4
  275. mindspore/nn/probability/distribution/gamma.py +10 -10
  276. mindspore/nn/probability/distribution/geometric.py +8 -8
  277. mindspore/nn/probability/distribution/gumbel.py +8 -9
  278. mindspore/nn/probability/distribution/half_normal.py +5 -5
  279. mindspore/nn/probability/distribution/laplace.py +5 -5
  280. mindspore/nn/probability/distribution/log_normal.py +12 -11
  281. mindspore/nn/probability/distribution/logistic.py +8 -8
  282. mindspore/nn/probability/distribution/normal.py +6 -5
  283. mindspore/nn/probability/distribution/poisson.py +10 -11
  284. mindspore/nn/probability/distribution/student_t.py +8 -9
  285. mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
  286. mindspore/nn/probability/distribution/uniform.py +11 -11
  287. mindspore/nn/reinforcement/tensor_array.py +2 -2
  288. mindspore/nn/sparse/sparse.py +9 -9
  289. mindspore/nn/wrap/cell_wrapper.py +188 -63
  290. mindspore/nn/wrap/grad_reducer.py +21 -12
  291. mindspore/nn/wrap/loss_scale.py +136 -49
  292. mindspore/numpy/__init__.py +4 -4
  293. mindspore/numpy/array_creations.py +55 -56
  294. mindspore/numpy/array_ops.py +134 -35
  295. mindspore/numpy/logic_ops.py +66 -20
  296. mindspore/numpy/math_ops.py +142 -139
  297. mindspore/numpy/utils_const.py +2 -2
  298. mindspore/offline_debug/convert_async.py +2 -2
  299. mindspore/ops/_grad_experimental/__init__.py +7 -5
  300. mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
  301. mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
  302. mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
  303. mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
  304. mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
  305. mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
  306. mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
  307. mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
  308. mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
  309. mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
  310. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
  311. mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
  312. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  313. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  314. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
  315. mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
  316. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
  317. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
  318. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
  319. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
  320. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  321. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
  322. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
  323. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
  324. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
  325. mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
  326. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  327. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  328. mindspore/ops/_op_impl/aicpu/cast.py +52 -0
  329. mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
  330. mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
  331. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  332. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
  333. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  334. mindspore/ops/_op_impl/aicpu/eye.py +4 -4
  335. mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
  336. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
  337. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  338. mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
  339. mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
  340. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  341. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  342. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  343. mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
  344. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
  345. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  346. mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
  347. mindspore/ops/_op_impl/aicpu/median.py +1 -0
  348. mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
  349. mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
  350. mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
  351. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
  352. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  353. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  354. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  355. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  356. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  357. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
  358. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
  359. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
  360. mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
  361. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  362. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  363. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  364. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  365. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
  366. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
  367. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  368. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  369. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  370. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  371. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  372. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
  373. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
  374. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
  375. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
  376. mindspore/ops/_op_impl/tbe/__init__.py +6 -4
  377. mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
  378. mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
  379. mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
  380. mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
  381. mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
  382. mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
  383. mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
  384. mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
  385. mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
  386. mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
  387. mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
  388. mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
  389. mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
  390. mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
  391. mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
  392. mindspore/ops/_op_impl/tbe/im2col.py +4 -4
  393. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  394. mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
  395. mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
  396. mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
  397. mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
  398. mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
  399. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  400. mindspore/ops/_primitive_cache.py +1 -1
  401. mindspore/ops/_tracefunc.py +241 -0
  402. mindspore/ops/_utils/utils.py +10 -2
  403. mindspore/ops/_vmap/vmap_array_ops.py +5 -3
  404. mindspore/ops/_vmap/vmap_base.py +5 -4
  405. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  406. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  407. mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
  408. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  409. mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
  410. mindspore/ops/arg_dtype_cast.py +54 -0
  411. mindspore/ops/composite/__init__.py +7 -5
  412. mindspore/ops/composite/base.py +78 -34
  413. mindspore/ops/composite/math_ops.py +5 -695
  414. mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
  415. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
  416. mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
  417. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
  418. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
  419. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
  420. mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
  421. mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
  422. mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
  423. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
  424. mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
  425. mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
  426. mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
  427. mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
  428. mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
  429. mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
  430. mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
  431. mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
  432. mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
  433. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  434. mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
  435. mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
  436. mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
  437. mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
  438. mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
  439. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  440. mindspore/ops/deprecated.py +304 -0
  441. mindspore/ops/function/__init__.py +41 -4
  442. mindspore/ops/function/array_func.py +1108 -467
  443. mindspore/ops/function/clip_func.py +94 -27
  444. mindspore/ops/function/debug_func.py +3 -1
  445. mindspore/ops/function/grad/grad_func.py +82 -73
  446. mindspore/ops/function/image_func.py +28 -12
  447. mindspore/ops/function/linalg_func.py +135 -39
  448. mindspore/ops/function/math_func.py +3779 -894
  449. mindspore/ops/function/nn_func.py +1584 -657
  450. mindspore/ops/function/parameter_func.py +13 -3
  451. mindspore/ops/function/random_func.py +247 -153
  452. mindspore/ops/function/sparse_func.py +14 -11
  453. mindspore/ops/function/sparse_unary_func.py +173 -47
  454. mindspore/ops/function/spectral_func.py +8 -4
  455. mindspore/ops/function/vmap_func.py +8 -7
  456. mindspore/ops/functional.py +47 -16
  457. mindspore/ops/op_info_register.py +346 -86
  458. mindspore/ops/operations/__init__.py +38 -22
  459. mindspore/ops/operations/_grad_ops.py +145 -149
  460. mindspore/ops/operations/_inner_ops.py +298 -56
  461. mindspore/ops/operations/_ms_kernel.py +3 -3
  462. mindspore/ops/operations/_quant_ops.py +24 -28
  463. mindspore/ops/operations/_rl_inner_ops.py +9 -7
  464. mindspore/ops/operations/_scalar_ops.py +115 -0
  465. mindspore/ops/operations/_sequence_ops.py +148 -10
  466. mindspore/ops/operations/_tensor_array.py +1 -1
  467. mindspore/ops/operations/_thor_ops.py +2 -2
  468. mindspore/ops/operations/array_ops.py +1239 -561
  469. mindspore/ops/operations/comm_ops.py +166 -90
  470. mindspore/ops/operations/control_ops.py +3 -3
  471. mindspore/ops/operations/custom_ops.py +124 -102
  472. mindspore/ops/operations/debug_ops.py +24 -11
  473. mindspore/ops/operations/image_ops.py +86 -71
  474. mindspore/ops/operations/inner_ops.py +18 -13
  475. mindspore/ops/operations/linalg_ops.py +30 -11
  476. mindspore/ops/operations/math_ops.py +1730 -435
  477. mindspore/ops/operations/nn_ops.py +1953 -943
  478. mindspore/ops/operations/other_ops.py +65 -43
  479. mindspore/ops/operations/random_ops.py +258 -98
  480. mindspore/ops/operations/rl_ops.py +4 -36
  481. mindspore/ops/operations/sparse_ops.py +38 -33
  482. mindspore/ops/operations/spectral_ops.py +8 -4
  483. mindspore/ops/primitive.py +66 -44
  484. mindspore/ops/signature.py +5 -5
  485. mindspore/parallel/_auto_parallel_context.py +80 -19
  486. mindspore/parallel/_cost_model_context.py +42 -0
  487. mindspore/parallel/_offload_context.py +162 -72
  488. mindspore/parallel/_parallel_serialization.py +2 -2
  489. mindspore/parallel/_ps_context.py +16 -4
  490. mindspore/parallel/_recovery_context.py +2 -1
  491. mindspore/parallel/_tensor.py +15 -13
  492. mindspore/parallel/_transformer/layers.py +8 -6
  493. mindspore/parallel/_transformer/loss.py +1 -0
  494. mindspore/parallel/_transformer/moe.py +7 -7
  495. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  496. mindspore/parallel/_transformer/transformer.py +34 -14
  497. mindspore/parallel/_utils.py +36 -14
  498. mindspore/parallel/algo_parameter_config.py +114 -20
  499. mindspore/parallel/checkpoint_transform.py +16 -18
  500. mindspore/parallel/shard.py +16 -13
  501. mindspore/profiler/__init__.py +1 -1
  502. mindspore/profiler/common/struct_type.py +3 -3
  503. mindspore/profiler/common/util.py +3 -2
  504. mindspore/profiler/envprofiling.py +11 -4
  505. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  506. mindspore/profiler/parser/ascend_flops_generator.py +94 -0
  507. mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
  508. mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
  509. mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
  510. mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
  511. mindspore/profiler/parser/ascend_op_generator.py +276 -0
  512. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  513. mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
  514. mindspore/profiler/parser/base_timeline_generator.py +11 -7
  515. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
  516. mindspore/profiler/parser/flops_parser.py +15 -11
  517. mindspore/profiler/parser/framework_parser.py +92 -73
  518. mindspore/profiler/parser/hccl_parser.py +16 -12
  519. mindspore/profiler/parser/integrator.py +22 -11
  520. mindspore/profiler/parser/memory_usage_parser.py +36 -11
  521. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  522. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  523. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  524. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  525. mindspore/profiler/parser/optime_parser.py +1 -1
  526. mindspore/profiler/parser/profiler_info.py +4 -5
  527. mindspore/profiler/parser/step_trace_parser.py +11 -14
  528. mindspore/profiler/profiling.py +678 -377
  529. mindspore/rewrite/api/node.py +211 -54
  530. mindspore/rewrite/api/node_type.py +5 -0
  531. mindspore/rewrite/api/pattern_engine.py +22 -23
  532. mindspore/rewrite/api/scoped_value.py +20 -17
  533. mindspore/rewrite/api/symbol_tree.py +252 -106
  534. mindspore/rewrite/api/tree_node_helper.py +3 -0
  535. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  536. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  537. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  538. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
  539. mindspore/rewrite/common/rewrite_elog.py +5 -1
  540. mindspore/rewrite/namer.py +51 -51
  541. mindspore/rewrite/namespace.py +14 -5
  542. mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
  543. mindspore/rewrite/node/call_function.py +79 -0
  544. mindspore/rewrite/node/cell_container.py +135 -0
  545. mindspore/rewrite/node/control_flow.py +88 -0
  546. mindspore/rewrite/{node.py → node/node.py} +313 -247
  547. mindspore/rewrite/node/node_manager.py +254 -0
  548. mindspore/rewrite/node/node_topological_manager.py +243 -0
  549. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  550. mindspore/rewrite/parsers/assign_parser.py +225 -239
  551. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  552. mindspore/rewrite/parsers/class_def_parser.py +179 -218
  553. mindspore/rewrite/parsers/constant_parser.py +9 -6
  554. mindspore/rewrite/parsers/container_parser.py +9 -7
  555. mindspore/rewrite/parsers/for_parser.py +36 -15
  556. mindspore/rewrite/parsers/function_def_parser.py +23 -20
  557. mindspore/rewrite/parsers/if_parser.py +28 -24
  558. mindspore/rewrite/parsers/module_parser.py +202 -25
  559. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  560. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  561. mindspore/rewrite/parsers/return_parser.py +6 -6
  562. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  563. mindspore/rewrite/sparsify/sparsify.py +4 -1
  564. mindspore/rewrite/sparsify/utils.py +11 -5
  565. mindspore/rewrite/symbol_tree.py +577 -732
  566. mindspore/rewrite/symbol_tree_builder.py +9 -175
  567. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  568. mindspore/run_check/_check_version.py +46 -39
  569. mindspore/run_check/run_check.py +3 -2
  570. mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
  571. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  572. mindspore/scipy/__init__.py +1 -1
  573. mindspore/scipy/linalg.py +67 -61
  574. mindspore/scipy/ops.py +5 -41
  575. mindspore/scipy/ops_grad.py +3 -2
  576. mindspore/scipy/ops_wrapper.py +5 -5
  577. mindspore/scipy/optimize/line_search.py +8 -8
  578. mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
  579. mindspore/scipy/optimize/minimize.py +16 -12
  580. mindspore/scipy/utils.py +1 -52
  581. mindspore/scipy/utils_const.py +4 -4
  582. mindspore/train/__init__.py +4 -4
  583. mindspore/train/_utils.py +13 -5
  584. mindspore/train/amp.py +410 -148
  585. mindspore/train/anf_ir_pb2.py +16 -4
  586. mindspore/train/callback/_backup_and_restore.py +8 -11
  587. mindspore/train/callback/_callback.py +80 -3
  588. mindspore/train/callback/_checkpoint.py +82 -51
  589. mindspore/train/callback/_early_stop.py +12 -15
  590. mindspore/train/callback/_history.py +1 -1
  591. mindspore/train/callback/_lambda_callback.py +13 -13
  592. mindspore/train/callback/_landscape.py +21 -17
  593. mindspore/train/callback/_loss_monitor.py +9 -10
  594. mindspore/train/callback/_on_request_exit.py +16 -33
  595. mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
  596. mindspore/train/callback/_summary_collector.py +44 -30
  597. mindspore/train/callback/_time_monitor.py +62 -12
  598. mindspore/train/data_sink.py +10 -16
  599. mindspore/train/dataset_helper.py +154 -86
  600. mindspore/train/loss_scale_manager.py +14 -9
  601. mindspore/train/metrics/__init__.py +10 -2
  602. mindspore/train/metrics/accuracy.py +1 -1
  603. mindspore/train/metrics/auc.py +1 -1
  604. mindspore/train/metrics/bleu_score.py +2 -2
  605. mindspore/train/metrics/confusion_matrix.py +14 -14
  606. mindspore/train/metrics/cosine_similarity.py +3 -3
  607. mindspore/train/metrics/dice.py +1 -1
  608. mindspore/train/metrics/fbeta.py +1 -1
  609. mindspore/train/metrics/hausdorff_distance.py +8 -6
  610. mindspore/train/metrics/mean_surface_distance.py +5 -4
  611. mindspore/train/metrics/metric.py +49 -17
  612. mindspore/train/metrics/occlusion_sensitivity.py +4 -4
  613. mindspore/train/metrics/perplexity.py +1 -1
  614. mindspore/train/metrics/precision.py +2 -2
  615. mindspore/train/metrics/recall.py +2 -3
  616. mindspore/train/metrics/roc.py +7 -7
  617. mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
  618. mindspore/train/metrics/topk.py +7 -4
  619. mindspore/train/mind_ir_pb2.py +193 -48
  620. mindspore/train/model.py +377 -133
  621. mindspore/train/serialization.py +697 -245
  622. mindspore/train/summary/_summary_adapter.py +5 -2
  623. mindspore/train/summary/_writer_pool.py +4 -3
  624. mindspore/train/summary/summary_record.py +25 -23
  625. mindspore/train/train_thor/convert_utils.py +39 -23
  626. mindspore/train/train_thor/dataset_helper.py +4 -3
  627. mindspore/train/train_thor/model_thor.py +8 -8
  628. mindspore/version.py +1 -1
  629. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
  630. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
  631. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  632. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  633. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  634. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  635. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  636. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  637. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  638. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  639. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  640. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  641. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  642. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  643. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  644. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  645. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  646. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  647. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  648. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  649. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  650. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  651. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  652. mindspore/_extends/graph_kernel/expander.py +0 -80
  653. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
  654. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  655. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  656. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  657. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  658. mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
  659. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  660. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  661. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  662. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  663. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  664. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  665. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  666. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  667. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  668. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  669. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  670. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  671. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  672. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  673. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  674. mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
  675. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  676. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  677. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  678. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  679. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  680. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  681. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  682. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  683. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  684. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  685. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  686. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  687. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  688. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  689. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  690. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  691. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  692. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  693. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  694. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  695. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  696. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  697. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  698. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  699. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  700. mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
  701. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  702. mindspore/_extends/parse/jit_fallback_modules.py +0 -51
  703. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  704. mindspore/dataset/engine/graphdata.py +0 -1586
  705. mindspore/include/api/net.h +0 -142
  706. mindspore/ops/_grad/grad_array_ops.py +0 -1347
  707. mindspore/ops/_grad/grad_clip_ops.py +0 -84
  708. mindspore/ops/_grad/grad_debug_ops.py +0 -68
  709. mindspore/ops/_grad/grad_inner_ops.py +0 -235
  710. mindspore/ops/_grad/grad_math_ops.py +0 -1684
  711. mindspore/ops/_grad/grad_nn_ops.py +0 -1529
  712. mindspore/ops/_grad/grad_other_ops.py +0 -89
  713. mindspore/ops/_grad/grad_sequence_ops.py +0 -296
  714. mindspore/ops/_grad/grad_sparse.py +0 -323
  715. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
  716. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
  717. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  718. mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
  719. mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
  720. mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
  721. mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
  722. mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
  723. mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
  724. mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
  725. mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
  726. mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
  727. mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
  728. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
  729. mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
  730. mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
  731. mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
  732. mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
  733. mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
  734. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
  735. mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
  736. mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
  737. mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
  738. mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
  739. mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
  740. mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
  741. mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
  742. mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
  743. mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
  744. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
  745. mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
  746. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  747. mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
  748. mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
  749. mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
  750. mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
  751. mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
  752. mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
  753. mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
  754. mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
  755. mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
  756. mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
  757. mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
  758. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  759. mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
  760. mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
  761. mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
  762. mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
  763. mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
  764. mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
  765. mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
  766. mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
  767. mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
  768. mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
  769. mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
  770. mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
  771. mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
  772. mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
  773. mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
  774. mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
  775. mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
  776. mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
  777. mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
  778. mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
  779. mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
  780. mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
  781. mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
  782. mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
  783. mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
  784. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
  785. mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
  786. mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
  787. mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
  788. mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
  789. mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
  790. mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
  791. mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
  792. mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
  793. mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
  794. mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
  795. mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
  796. mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
  797. mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
  798. mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
  799. mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
  800. mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
  801. mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
  802. mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
  803. mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
  804. mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
  805. mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
  806. mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
  807. mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
  808. mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
  809. mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
  810. mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
  811. mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
  812. mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
  813. mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
  814. mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
  815. mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
  816. mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
  817. mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
  818. mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
  819. mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
  820. mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
  821. mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
  822. mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
  823. mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
  824. mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
  825. mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
  826. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
  827. mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
  828. mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
  829. mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
  830. mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
  831. mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
  832. mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
  833. mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
  834. mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
  835. mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
  836. mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
  837. mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  838. mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
  839. mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
  840. mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
  841. mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
  842. mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
  843. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  844. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  845. mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
  846. mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
  847. mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
  848. mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
  849. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
  850. mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
  851. mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
  852. mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
  853. mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
  854. mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
  855. mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
  856. mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
  857. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
  858. mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
  859. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
  860. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
  861. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
  862. mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
  863. mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
  864. mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
  865. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  866. mindspore/rewrite/node_visitor.py +0 -44
  867. mindspore/rewrite/topological_manager.py +0 -203
  868. mindspore/scipy/sparse/linalg.py +0 -192
  869. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  870. {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -14,30 +14,31 @@
14
14
  # ============================================================================
15
15
  """SymbolTree class define of Rewrite according to forward function of a network."""
16
16
  import stat
17
- from typing import Optional, Union, Tuple, Any
17
+ from typing import Optional, Union, Tuple, Any, Dict, List
18
18
  import os
19
19
  import sys
20
20
  import ast
21
- import importlib
22
- import types
21
+ import importlib.util
23
22
  import time
24
- import astunparse
25
23
 
26
24
  from mindspore.nn import Cell
27
25
  from mindspore import log as logger
28
- from mindspore.rewrite.ast_creator_register import ast_creator_registry
29
- from .node import Node, TreeNode
26
+ from .node.node import Node, TreeNode
30
27
  from .api.node_type import NodeType
31
- from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
28
+ from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
32
29
  from .api.scoped_value import ScopedValue, ValueType
33
30
  from .symbol_tree_dumper import SymbolTreeDumper
34
- from .topological_manager import TopoManager
31
+ from .node.node_topological_manager import TopoManager
35
32
  from .namer import TargetNamer, NodeNamer, ClassNamer
36
33
  from .common.observer import Observer
37
34
  from .common.observable import Observable
38
35
  from .common.event import Event
39
- from .node_visitor import NodeVisitor
36
+ from .node.node_manager import NodeManager
40
37
 
38
+ if sys.version_info >= (3, 9):
39
+ import ast as astunparse # pylint: disable=reimported, ungrouped-imports
40
+ else:
41
+ import astunparse
41
42
 
42
43
  class Position:
43
44
  """
@@ -81,6 +82,7 @@ class FieldFinder(AstFinder):
81
82
  Args:
82
83
  scope (ast.AST): An instance of ast node as search scope.
83
84
  """
85
+
84
86
  def __init__(self, scope: ast.AST):
85
87
  super().__init__(scope)
86
88
  self._result = False
@@ -134,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
134
136
  self.generic_visit(node)
135
137
 
136
138
 
137
- class SymbolTree(Observer, Observable):
139
+ class SymbolTree(Observer, Observable, NodeManager):
138
140
  """
139
141
  A symbol-tree usually corresponding to forward method of a network.
140
142
 
@@ -147,227 +149,138 @@ class SymbolTree(Observer, Observable):
147
149
  """
148
150
 
149
151
  def __init__(self, origin_network: Cell, module_ast: ast.Module):
150
- super().__init__()
152
+ Observer.__init__(self)
151
153
  Observable.__init__(self)
152
- origin_network_key = "handler"
154
+ self._node_namer = NodeNamer()
155
+ self._node_namer.add_name('obj')
156
+ NodeManager.__init__(self, self._node_namer)
157
+ NodeManager.reg_observer(self, observer=self)
153
158
  # init unique-namers
154
159
  self._target_namer = TargetNamer()
155
- self._node_name_namer = NodeNamer()
156
- # name or node would use as name of field, so name of origin network handler field should be added into \
157
- # _node_name_namer.
158
- self._node_name_namer.add_name(origin_network_key)
159
- self._topo_mgr = TopoManager()
160
- self._topo_mgr.reg_observer(self)
161
-
162
- self._nodes: {str, Node} = {}
163
- # parameters of forward method
164
- self._inputs: [Node] = []
160
+ # input arguments of function
165
161
  self._ori_cls_name = type(origin_network).__name__
166
162
  self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
163
+ NodeManager.set_manager_name(self, self._opt_cls_name)
167
164
  self._origin_network = origin_network
168
165
  self._module_ast: ast.Module = module_ast
166
+ self._import_asts: Optional[ast.Ast] = []
169
167
  self._class_ast: Optional[ast.ClassDef] = None
170
168
  self._root_ast: Optional[ast.FunctionDef] = None
171
169
  self._init_func_ast: Optional[ast.FunctionDef] = None
172
170
  self._deleted_field = {}
173
171
  self._deleted_node = []
174
- self._external_func_ast = []
172
+ self._external_ast = []
175
173
  self._father_class_ast = []
176
-
177
- # head node is always point to the first node(in source code order) of SymbolTree
178
- self._head = None
179
- # tail node is always point to the last node(in source code order) of SymbolTree
180
- self._tail = None
181
- self._return: Optional[Node] = None
182
-
183
174
  self._modified = False
184
- self._node_visitor = None
185
-
186
175
  self._tmp_file_limits = 20
187
176
  self._tmp_files = []
188
177
  self._saved_file_name = "./network_define.py"
178
+ # used to insert "sys.path.append(xxx)"
179
+ self._net_file_paths = []
180
+ self._tmp_import_strs = []
181
+ self._tmp_unmodified_strees: {type, str} = {}
182
+ self._tmp_replacers = []
183
+ # Record imported modules and names of each files
184
+ # The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
185
+ # Format: {file_path: {module: [name, ...], ...}, ...}
186
+ self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
189
187
 
190
188
  def __del__(self):
191
189
  for tmp_file in self._tmp_files:
192
190
  tmp_file.close()
193
191
 
194
192
  @staticmethod
195
- def _find_consumers_and_providers(nodes: [Node]):
196
- """
197
- Find consumers and providers for all nodes according to their targets and arguments.
198
- """
199
- consumers: {ScopedValue: [Node]} = {}
200
- providers: {ScopedValue: Node} = {}
201
- for node in nodes:
202
- for arg in node.get_args():
203
- if consumers.get(arg):
204
- consumers[arg].append(node)
205
- else:
206
- consumers[arg] = [node]
207
- for _, arg in node.get_kwargs():
208
- if consumers.get(arg):
209
- consumers[arg].append(node)
210
- else:
211
- consumers[arg] = [node]
212
- for target in node.get_targets():
213
- if providers.get(target) is not None:
214
- raise RuntimeError(f"Target({target}) of node duplicated")
215
- providers[target] = node
216
- return consumers, providers
217
-
218
- @staticmethod
219
- def _link_nodes_and_find_root(nodes: [Node]) -> Node:
220
- """
221
- Find inputs for all nodes created by Replacement according to their targets and arguments.
222
-
223
- Find root node of all nodes created by Replacement. One and Only one root should be found.
224
-
225
- Args:
226
- nodes (list[Node]): A list of instance of Node created by Replacement.
227
-
228
- Returns:
229
- An instance of Node represents root of input nodes.
230
- """
231
- consumers, providers = SymbolTree._find_consumers_and_providers(nodes)
232
- # find root node
233
- root = None
234
- for node in nodes:
235
- used = 0
236
- for target in node.get_targets():
237
- cur_consumers = consumers.get(target)
238
- if not cur_consumers:
239
- continue
240
- for cur_consumer in cur_consumers:
241
- if id(cur_consumer) != id(node):
242
- used += 1
243
- break
244
- if used == 0:
245
- if root is not None:
246
- raise RuntimeError("Replacement should only has one root, found multi-root")
247
- root = node
248
- if root is None:
249
- raise RuntimeError("Replacement should only has one root, found no root")
250
- # link node's input
251
- for node in nodes:
252
- inputs = []
253
- for _, arg in node.get_normalized_args().items():
254
- node_input: Node = providers.get(arg)
255
- if id(node_input) != id(node):
256
- inputs.append(node_input)
257
- node.set_inputs(inputs)
258
- return root
259
-
260
- @staticmethod
261
- def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
262
- """Find all non-duplicated class name of SymbolTree recursively."""
263
- replacer = AstReplacer(stree._class_ast)
264
- replacers.append(replacer)
265
- for node in stree.nodes():
266
- if not isinstance(node, TreeNode):
193
+ def _remove_unused_import(module_ast):
194
+ """remove unused import in self._module_ast"""
195
+ str_checker = StrChecker(module_ast)
196
+ for i in range(len(module_ast.body) - 1, -1, -1):
197
+ body = module_ast.body[i]
198
+ if not isinstance(body, (ast.Import, ast.ImportFrom)):
267
199
  continue
268
- if node.symbol_tree._class_ast is None:
200
+ if isinstance(body, ast.Import):
269
201
  continue
270
- sub_stree: SymbolTree = node.symbol_tree
271
- SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
272
- # all modified ast.ClassDef should export to code
273
- if sub_stree._modified:
274
- allow_class_name.append(sub_stree._class_ast.name)
202
+ if isinstance(body, ast.ImportFrom) and body.module == "cell":
203
+ module_ast.body.remove(body)
275
204
  continue
276
- # all un-modified ast.ClassDef only keep one instance
277
- seen_cls_name = seen_class.get(type(sub_stree.get_origin_network()))
278
- if seen_cls_name is not None:
279
- replacer.replace_all(sub_stree._class_ast.name, seen_cls_name)
280
- else:
281
- seen_class[type(sub_stree.get_origin_network())] = sub_stree._class_ast.name
282
- allow_class_name.append(sub_stree._class_ast.name)
205
+ for alias in body.names:
206
+ name = alias.asname if alias.asname else alias.name
207
+ if not str_checker.check(name):
208
+ if len(body.names) == 1:
209
+ module_ast.body.remove(body)
210
+ i += 1
211
+ else:
212
+ body.names.remove(alias)
213
+
214
+ @staticmethod
215
+ def _remove_duplicated_import(module_ast):
216
+ """Remove duplicated import of 'net'."""
217
+ imports = set()
218
+ futures = set()
219
+ classes = set()
220
+
221
+ class TransImportNode(ast.NodeTransformer):
222
+ """Find all import nodes from input ast node."""
223
+
224
+ def visit_ClassDef(self, node: ast.ClassDef) -> Any:
225
+ class_str = astunparse.unparse(node)
226
+ if class_str not in classes:
227
+ classes.add(node.name)
228
+ return node
229
+ return
230
+
231
+ def visit_Try(self, node: ast.Try) -> Any:
232
+ if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
233
+ import_str = astunparse.unparse(node)
234
+ if import_str not in imports:
235
+ imports.add(import_str)
236
+ return node
237
+ return
238
+
239
+ def visit_Import(self, node: ast.Import) -> Any:
240
+ import_str = astunparse.unparse(node)
241
+ if import_str not in imports:
242
+ imports.add(import_str)
243
+ return node
244
+ return
245
+
246
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
247
+ """
248
+ Once the father class 'A' is defined in the current module, all the next imported class 'A' should
249
+ be removed. e.g.
250
+ def class A():
251
+ ...
252
+ from xxx import A, B
253
+ =>
254
+ def class A():
255
+ ...
256
+ from xxx import B
257
+ """
258
+ import_str = astunparse.unparse(node)
259
+
260
+ if import_str not in imports:
261
+ imports.add(import_str)
262
+ # remove "__future__" module
263
+ if node.module == '__future__':
264
+ futures.add(node.module)
265
+ return
266
+ # remove modules which have been defined in the code file
267
+ # it occurs when class A is a father class and other sub-classes import A
268
+ for alias in node.names[:]:
269
+ if alias.name in classes:
270
+ node.names.remove(alias)
271
+ # if the alias(es) in node.names are all removed, this import statement should be removed
272
+ if not node.names:
273
+ return
274
+ return node
275
+ return
276
+
277
+ get_node_handler = TransImportNode()
278
+ get_node_handler.generic_visit(module_ast)
283
279
 
284
280
  def finish_build(self):
285
281
  """Add Event.TopologicalChangeEvent event when build is finished."""
286
282
  self.add_event(Event.TopologicalChangeEvent)
287
283
 
288
- def _create_call_function(self, func, targets, args, kwargs):
289
- """
290
- Create a Node object and generate the execution code to insert into the source code.
291
- The source code calls the 'func' function with 'args' and' kwargs' as parameters.
292
-
293
- Args:
294
- func (FunctionType) - The function to be called.
295
- targets (list [str]) - indicates the output name. As the output of the node in the source code.
296
- args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
297
- code. The default value is None, which means there is no parameter input in the cell.
298
- kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
299
- input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
300
- code as the 'kwargs' in the statement expression. The default value is None, which means there is no
301
- 'kwargs' input.
302
-
303
- Returns:
304
- An instance of `Node`.
305
- """
306
- if not isinstance(func, types.FunctionType):
307
- raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
308
-
309
- _package = func.__globals__['__package__']
310
- func_name = ".".join([_package, func.__name__]) if _package else func.__name__
311
-
312
- ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
313
- scope_targets = [ScopedValue.create_naming_value(targets[0])]
314
- scope_func = ScopedValue.create_naming_value(func_name, "")
315
- call_args = list()
316
- for arg in args:
317
- if isinstance(arg, Node):
318
- call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
319
- else:
320
- call_args.append(ScopedValue.create_variable_value(arg))
321
- call_kwargs = {}
322
- for k, v in kwargs.items():
323
- call_kwargs[k] = ScopedValue.create_variable_value(v)
324
- node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
325
- call_kwargs)
326
- return node
327
-
328
- def create_assign_node(self, targets, func_name, args, kwargs):
329
- """
330
- Create a ast.Assign type node.
331
-
332
- Args:
333
- targets (list): _description_
334
- func_name (_type_): _description_
335
- args (_type_): _description_
336
- kwargs (_type_): _description_
337
-
338
- Returns:
339
- _type_: _description_
340
- """
341
- # create targets
342
- ast_targets = [ast_creator_registry.get("Name")(targets)]
343
- # create call
344
- ast_func = ast_creator_registry.get("Attribute")(func_name)
345
- ast_args = ast_creator_registry.get("Args")(args)
346
- ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
347
- ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
348
- # create assign
349
- ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
350
- return ast_node
351
-
352
- def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
353
- '''
354
- Instantiate an instance of node whose type is `CallFunction`.
355
-
356
- Args:
357
- node_name (str): Name of node.
358
- func_name (str): Name of function.
359
- ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
360
- targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
361
- func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
362
- args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
363
- kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
364
- class.
365
- '''
366
- logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
367
- node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
368
- node.set_belong_symbol_tree(self)
369
- return node
370
-
371
284
  def get_ori_cls_name(self) -> str:
372
285
  """
373
286
  Get class name of original network.
@@ -423,6 +336,7 @@ class SymbolTree(Observer, Observable):
423
336
  corresponding network class.
424
337
  """
425
338
  self._root_ast = ast_node
339
+ NodeManager.set_ast_functiondef(self, ast_node)
426
340
 
427
341
  def get_class_ast(self):
428
342
  """
@@ -461,18 +375,6 @@ class SymbolTree(Observer, Observable):
461
375
  """
462
376
  self._init_func_ast = ast_node
463
377
 
464
- def get_inputs(self):
465
- return self._inputs
466
-
467
- def get_head_node(self):
468
- """
469
- Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
470
-
471
- Returns:
472
- An instance of node.
473
- """
474
- return self._head
475
-
476
378
  def get_origin_network(self):
477
379
  """
478
380
  Getter of `_origin_network`.
@@ -486,36 +388,53 @@ class SymbolTree(Observer, Observable):
486
388
  """Get dict of nodes"""
487
389
  return self._nodes
488
390
 
489
- def nodes(self):
391
+ def get_node_namer(self):
392
+ """Get _node_namer"""
393
+ return self._node_namer
394
+
395
+ def is_modified(self):
490
396
  """
491
- Get generator of nodes of current `SymbolTree`.
397
+ Check whether symbol tree is modified.
492
398
 
493
- Returns:
494
- A generator for iterating Nodes of `SymbolTree`.
399
+ Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
400
+ the symbol tree is created.
495
401
  """
496
- if self._node_visitor is None:
497
- self._node_visitor = NodeVisitor(self)
498
- it = iter(self._node_visitor)
402
+ return self._modified
499
403
 
500
- while True:
501
- try:
502
- n = next(it)
503
- except StopIteration:
504
- return None
505
- yield n
404
+ def set_modified_true(self):
405
+ """
406
+ Set self._modified true.
506
407
 
507
- def get_node(self, node_name: str) -> Optional[Node]:
408
+ Self._modified is set true when 'if' exists in the original network.
409
+ In this situation, different original network instance tends to be different.
410
+ Hence, the class name should be updated.
508
411
  """
509
- Get node of current symbol_tree by `node_name`.
412
+ self._modified = True
510
413
 
511
- Args:
512
- node_name (str): A str represents name of node as key of query.
414
+ def get_import_asts(self):
415
+ """Get _import_asts"""
416
+ return self._import_asts
513
417
 
514
- Returns:
515
- An instance of Node if found else None.
516
- """
418
+ def get_external_ast(self):
419
+ """Get _external_ast"""
420
+ return self._external_ast
421
+
422
+ def get_father_class_ast(self):
423
+ """Get _father_class_ast"""
424
+ return self._father_class_ast
517
425
 
518
- return self._nodes.get(node_name)
426
+ def get_imported_modules(self, file_path: str):
427
+ """Get all modules and module_paths in file of `file_path` ."""
428
+ return self._imported_modules.get(file_path, {})
429
+
430
+ def save_imported_modules(self, file_path: str, module: str, names: List[str]):
431
+ """Save module and names into _imported_modules."""
432
+ imported_modules = self.get_imported_modules(file_path)
433
+ if imported_modules.get(module):
434
+ imported_modules[module].extend(names)
435
+ else:
436
+ imported_modules[module] = names
437
+ self._imported_modules[file_path] = imported_modules
519
438
 
520
439
  def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
521
440
  """
@@ -553,7 +472,7 @@ class SymbolTree(Observer, Observable):
553
472
  return []
554
473
  if real_node.get_node_type() == NodeType.Output:
555
474
  return []
556
- return self._topo_mgr.get_node_users(node_or_name)
475
+ return TopoManager.get_node_users(real_node)
557
476
 
558
477
  def before(self, node_or_name: Union[Node, str]) -> Position:
559
478
  """
@@ -606,9 +525,11 @@ class SymbolTree(Observer, Observable):
606
525
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
607
526
  return Position.create(node.get_belong_symbol_tree(), node, False)
608
527
 
609
- def insert_node(self, position: Optional[Position], node: Node, insert_to_ast: bool = True) -> Node:
528
+ def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
529
+ insert_to_ast: bool = True):
610
530
  """
611
- Insert a node into SymbolTree.
531
+ Insert a node before or after base_node.
532
+
612
533
  Note:
613
534
  Name of node will be unique while inserting node into SymbolTree.
614
535
 
@@ -627,52 +548,73 @@ class SymbolTree(Observer, Observable):
627
548
  Topological relation is updated and inputs of corresponding node is updated.
628
549
 
629
550
  Args:
630
- position (Position): A Position indicates an insert position point.
631
- node (Node): An instance of node to be inserted in.
632
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
633
- True.
551
+ new_node (Node): Node to be inserted.
552
+ base_node (Node): New node will be inserted before or after base_node.
553
+ before_node (bool): Indicate whether new node is inserted before base_node.
554
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
555
+ NodeManager of symboltree's construct function.
556
+ insert_to_ast (bool): Indicate whether ast nodes need to be updated.
634
557
 
635
558
  Returns:
636
559
  An instance of node which has been inserted into SymbolTree.
637
560
 
638
561
  Raises:
639
- RuntimeError: If 'position' is not in current SymbolTree.
562
+ ValueError: Node in the SymbolTree is inserted into SymbolTree again.
640
563
  RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
641
564
  """
642
- if position is not None and hasattr(position.node, "container"):
643
- cellcontainer = getattr(position.node, "container")
644
- index = cellcontainer.node_list.index(position.node)
645
- index = index if position.before_node else index + 1
646
- cellcontainer.insert(index, node)
647
- return node
648
- # if position in current SymbolTree
649
- if position is not None and position.symbol_tree is not self:
650
- raise RuntimeError("Position is not in current SymbolTree:", position)
651
- if position is not None and position.node.get_node_type() == NodeType.Input:
565
+ if new_node.get_belong_symbol_tree():
566
+ raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
567
+
568
+ # Check if base_node in current SymbolTree
569
+ if base_node is not None:
570
+ stree = base_node.get_belong_symbol_tree()
571
+ if stree is not None and stree is not self:
572
+ raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
573
+ f"current: {self.get_ori_cls_name()}.")
574
+
575
+ # Check if node is inserted between Input node
576
+ if base_node is not None and base_node.get_node_type() == NodeType.Input:
652
577
  valid = True
653
- if position.before_node:
578
+ if before_node:
654
579
  valid = False
655
- if position.node.get_next() is not None and position.node.get_next().get_node_type() == NodeType.Input:
580
+ if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
656
581
  valid = False
657
582
  if not valid:
658
- raise RuntimeError("Can not insert a node before or between parameters:", position)
659
- # unique targets, name while insert node into symbol_tree
660
- node_name = self._node_name_namer.get_name(node)
661
- node.set_name(node_name)
662
- self._handle_custom_obj_in_normalized_args(node)
663
- # _unique_targets must called after _update_args_for_unique and _update_kwargs_for_unique
664
- self._unique_targets(node)
665
- self._insert_node(position, node)
666
- if isinstance(node, TreeNode):
667
- node.symbol_tree.reg_observer(self)
668
- if self._node_visitor:
669
- self._node_visitor.append_node(node)
670
- # update init-function-ast and construct-function-ast
671
- if insert_to_ast:
672
- self._insert_to_ast_while_insert_node(node, position)
673
- return node
583
+ raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
584
+
585
+ # save target name, which is used to provide unique target
586
+ if new_node.get_targets():
587
+ for target in new_node.get_targets():
588
+ self._target_namer.add_name(str(target))
589
+
590
+ self._handle_custom_obj_in_normalized_args(new_node)
591
+
592
+ # Insert node into NodeManager
593
+ if node_manager is None:
594
+ if base_node is None:
595
+ raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
596
+ node_manager = base_node.get_node_manager()
597
+
598
+ # set node's _belong_symbol_tree
599
+ new_node.set_belong_symbol_tree(self)
600
+
601
+ if node_manager is self:
602
+ NodeManager.insert_node(self, new_node, base_node, before_node)
603
+ if insert_to_ast:
604
+ # update init-function-ast and construct-function-ast
605
+ self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
606
+ else:
607
+ node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
674
608
 
675
- def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
609
+ # register code changed event observer, which is used to update _modified flag.
610
+ if new_node.get_node_type() == NodeType.Tree:
611
+ new_node.symbol_tree.reg_observer(self)
612
+ elif isinstance(new_node, NodeManager):
613
+ new_node.reg_observer(self)
614
+
615
+ return new_node
616
+
617
+ def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
676
618
  """
677
619
  Append a node to SymbolTree.
678
620
 
@@ -680,13 +622,17 @@ class SymbolTree(Observer, Observable):
680
622
  node (Node): An instance of node to be appended.
681
623
  append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
682
624
  True.
625
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
626
+ NodeManager of symboltree's construct function.
683
627
 
684
628
  Returns:
685
629
  An instance of node which has been appended to SymbolTree.
686
630
  """
687
- return self.insert_node(Position.create(self, self._tail, False), node, append_to_ast)
631
+ if node_manager is None:
632
+ node_manager = self
633
+ return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
688
634
 
689
- def append_origin_field(self, node: Node) -> Node:
635
+ def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
690
636
  """
691
637
  Append an original field node to SymbolTree. An original field node represents a node created from existing
692
638
  statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
@@ -695,18 +641,16 @@ class SymbolTree(Observer, Observable):
695
641
 
696
642
  Args:
697
643
  node (Node): An instance of node to be appended.
644
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
645
+ NodeManager of symboltree's construct function.
698
646
 
699
647
  Returns:
700
648
  An instance of node which has been appended to SymbolTree.
701
649
  """
702
- self._update_args_kwargs_for_unique(node)
703
- if node.get_node_type() == NodeType.Output:
704
- self._return = node
705
- elif node.get_node_type() == NodeType.Input:
706
- self._inputs.append(node)
707
- return self.append_node(node, False)
650
+ return self.append_node(node, node_manager, False)
708
651
 
709
- def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None):
652
+ def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
653
+ node_manager: NodeManager = None):
710
654
  """
711
655
  Append an input node to SymbolTree corresponding to parameter of forward method of network class.
712
656
  This method is called while building SymbolTree usually.
@@ -716,13 +660,18 @@ class SymbolTree(Observer, Observable):
716
660
  param_name (str): A str represents name of parameter of forward method of network class.
717
661
  default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
718
662
  means parameter has no default value.
663
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
664
+ NodeManager of symboltree's construct function.
719
665
 
720
666
  Returns:
721
667
  An instance of input node which has been appended to SymbolTree.
722
668
  """
723
669
  if param_name == "self":
724
670
  return
725
- for input_node in self._inputs:
671
+ # check param_name duplicated
672
+ if node_manager is None:
673
+ node_manager = self
674
+ for input_node in node_manager._inputs:
726
675
  targets = input_node.get_targets()
727
676
  if len(targets) != 1:
728
677
  raise RuntimeError("targets should have 1 elements")
@@ -735,9 +684,10 @@ class SymbolTree(Observer, Observable):
735
684
  if exist_param == param_name:
736
685
  raise RuntimeError("input duplicated:", param_name)
737
686
  input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
738
- self.append_origin_field(input_node)
687
+ self.append_origin_field(input_node, node_manager)
739
688
 
740
- def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Optional[Node]:
689
+ def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
690
+ node_manager: NodeManager = None) -> Optional[Node]:
741
691
  """
742
692
  Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
743
693
  a list or a dict.
@@ -746,6 +696,8 @@ class SymbolTree(Observer, Observable):
746
696
  Args:
747
697
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
748
698
  ast_node (ast.AST): A ast node represents ast node.
699
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
700
+ NodeManager of symboltree's construct function.
749
701
 
750
702
  Returns:
751
703
  An instance of python node if a new node has been appended to SymbolTree else None.
@@ -754,9 +706,9 @@ class SymbolTree(Observer, Observable):
754
706
  return None
755
707
  if isinstance(ast_node, (list, dict)) and not ast_node:
756
708
  return None
757
- return self.append_python_node(ast_scope, ast_node)
709
+ return self.append_python_node(ast_scope, ast_node, node_manager)
758
710
 
759
- def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
711
+ def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
760
712
  """
761
713
  Append a python node to SymbolTree.
762
714
  This method is called while building SymbolTree usually.
@@ -764,40 +716,50 @@ class SymbolTree(Observer, Observable):
764
716
  Args:
765
717
  ast_scope (ast.AST): A ast node represents ast node of scope of node.
766
718
  ast_node (ast.AST): A ast node represents ast node.
719
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
720
+ NodeManager of symboltree's construct function.
767
721
 
768
722
  Returns:
769
723
  An instance of python node which has been appended to SymbolTree.
770
724
  """
771
725
  logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
772
- node_name = self._node_name_namer.get_name(type(ast_node).__name__)
773
- self._update_names_for_unique(ast_node)
726
+ node_name = type(ast_node).__name__
774
727
  node = Node.create_python_node(ast_node, node_name)
775
- self._insert_node(Position.create(self, self._tail, False), node)
728
+ if node_manager is None or node_manager is self:
729
+ NodeManager.append_python_node(self, node)
730
+ else:
731
+ node_manager.append_python_node(node)
776
732
  return node
777
733
 
778
- def set_output(self, return_value: str, index: int) -> Node:
734
+ def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
735
+ node_manager: NodeManager = None) -> Node:
779
736
  """
780
737
  Update return value of return of forward method of network class.
781
738
 
782
739
  Args:
783
740
  return_value (str): A str represents new return value.
784
- index (int): A int indicates which return value to be updated.
741
+ arg_index (int): A int indicates which value in return to be updated.
742
+ return_idx (int): A int indicates which return node to be updated. Default: 0.
743
+ node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
744
+ symboltree's construct function.
785
745
 
786
746
  Returns:
787
747
  An instance of node represents return node after updated.
788
748
  """
789
- if self._return is None:
790
- raise RuntimeError("SymbolTree has no output")
791
- self.set_node_arg(self._return, index, return_value)
792
- return self._return
749
+ node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
750
+ if not node_returns:
751
+ raise RuntimeError("Current node_manager has no output")
752
+ if return_idx >= len(node_returns):
753
+ raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
754
+ node_return = node_returns[return_idx]
755
+ self.set_node_arg(node_return, arg_index, return_value)
756
+ return node_return
793
757
 
794
758
  def erase_node(self, node_or_name: Union[Node, str]) -> Node:
795
759
  """
796
760
  Erase a node from SymbolTree.
797
- Note:
798
- If node is depended on by other node, RuntimeError will raise.
799
761
 
800
- Topological relation is updated.
762
+ Topological relation will be updated.
801
763
 
802
764
  Args:
803
765
  node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
@@ -813,71 +775,51 @@ class SymbolTree(Observer, Observable):
813
775
  node = self._get_real_node(node_or_name)
814
776
  if node is None:
815
777
  raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
816
- if hasattr(node, "container"):
817
- cellcontainer = getattr(node, "container")
818
- cellcontainer.erase(node)
819
- return node
820
- ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
821
- if not ret:
822
- raise RuntimeError("node not in function ast tree.")
823
- for key, value in self._nodes.items():
824
- if id(value) == id(node):
825
- self._nodes.pop(key)
826
- value.isolate()
827
- break
828
- self._topo_mgr.on_erase_node(node)
778
+ # erase node in NodeManager
779
+ node_manager = node.get_node_manager()
780
+
781
+ logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
782
+ f"node_manager: {node_manager.get_manager_name()}, "
783
+ f"code: {astunparse.unparse(node.get_ast()).strip()}, "
784
+ f"node_name:{node.get_name()}")
785
+
786
+ if node_manager is self:
787
+ NodeManager.erase_node(self, node)
788
+ ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
789
+ if not ret:
790
+ raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
791
+ else:
792
+ node_manager.erase_node(node)
829
793
  self._deleted_node.append(node.get_name())
830
794
  return node
831
795
 
832
796
  def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
833
797
  """
834
- Replace an old_node with a node_tree. 'new_node' is the root node of the node_tree.
835
- Note:
836
- Rewrite will iterate all nodes linked to this root node and insert these nodes into symbol_tree.
837
-
838
- Inputs of intra sub-tree nodes need to be welly set.
839
-
840
- Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
798
+ Replace an old_node with a node list.
841
799
 
842
800
  Args:
843
801
  old_node (Node): Node to be replaced.
844
- new_nodes (list[Node]): Node tree to replace in.
802
+ new_nodes (list[Node]): Node list to replace in.
845
803
 
846
804
  Returns:
847
- An instance of Node represents root of node_tree been replaced in.
805
+ Last node in new_nodes list.
848
806
 
849
807
  Raises:
850
808
  RuntimeError: If 'old_node' is isolated.
851
809
  RuntimeError: If 'old_node' is not belong to current SymbolTree.
852
810
  """
853
-
854
- if hasattr(old_node, "container"):
855
- self._replace_container_node(old_node, new_nodes)
856
- return new_nodes[0]
857
811
  real_old_node = self._get_real_node(old_node)
858
812
  if real_old_node is None:
859
813
  raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
860
- # get position
861
- next_node: Node = old_node.get_next()
862
- prev_node: Node = old_node.get_prev()
863
- if prev_node is None and next_node is None:
864
- raise RuntimeError("Try replacing a isolated node: ", old_node)
865
- if prev_node is None:
866
- position = self.before(next_node)
867
- else:
868
- position = self.after(prev_node)
869
- # insert node first, because targets of new_node is determined after insert
870
- new_tree_root = SymbolTree._link_nodes_and_find_root(new_nodes)
871
- new_node = self._insert_tree(position, new_tree_root)
872
- # use targets of insert tree to redirect edge
873
- users = self.get_node_users(old_node)
874
- if len(new_node.get_targets()) != 1:
875
- raise RuntimeError("targets of new_node should have 1 elements")
876
- for user in users:
877
- self.set_node_arg_by_node(user[0], user[1], new_node)
878
- # erase old_node after edge is redirected because node can be erased only when node is isolated topologically
814
+ # insert new_nodes into node_manager
815
+ node_manager = real_old_node.get_node_manager()
816
+ # insert new_nodes into NodeManager
817
+ base_node = old_node
818
+ for node in new_nodes:
819
+ self.insert_node(node, base_node, False, node_manager, True)
820
+ base_node = node
879
821
  self.erase_node(old_node)
880
- return new_node
822
+ return new_nodes[-1]
881
823
 
882
824
  def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
883
825
  """
@@ -933,30 +875,234 @@ class SymbolTree(Observer, Observable):
933
875
  if out_idx >= len(targets):
934
876
  raise RuntimeError("out_idx out of range: ", out_idx)
935
877
  new_arg = targets[out_idx]
936
- self.set_node_arg(real_dst_node, arg_idx, new_arg)
878
+ real_dst_node.set_arg(new_arg, arg_idx)
879
+ self._topo_mgr.on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
880
+
881
+ def unique_name(self, name: str):
882
+ """Get a unique name in the symboltree"""
883
+ return self._target_namer.get_name(name)
884
+
885
+ def unique_func_name(self, name: str):
886
+ """Get a unique function name in the symboltree"""
887
+ if not hasattr(self._origin_network, name):
888
+ return name
889
+ suffix = 1
890
+ while hasattr(self._origin_network, f"{name}_{suffix}"):
891
+ suffix += 1
892
+ return f"{name}_{suffix}"
893
+
894
+ def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
895
+ """
896
+ Set target of `node` .
897
+
898
+ Args:
899
+ node (Union[Node, str]): Node to be modified. Can be a node or name of node.
900
+ index (int): Indicate which target being modified.
901
+ arg (Union[ScopedValue, str]): New target to been set.
937
902
 
938
- def print_node_tabulate(self):
903
+ Raises:
904
+ ValueError: If `node` is not belong to current SymbolTree.
905
+ ValueError: If index of `node` 's target is greater than number of targets.
906
+ """
907
+
908
+ real_node = self._get_real_node(node)
909
+ if real_node is None:
910
+ raise ValueError("Node is not belong to current SymbolTree: ", node)
911
+ if isinstance(target, str):
912
+ target = ScopedValue.create_naming_value(target)
913
+ targets = node.get_targets()
914
+ if index >= len(targets):
915
+ raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}")
916
+ old_target = targets[index]
917
+ targets[index] = target
918
+ node.set_targets(targets)
919
+ self._topo_mgr.on_update_target(node, index, old_target, target)
920
+
921
+ def all_nodes(self):
922
+ """
923
+ Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
924
+
925
+ Returns:
926
+ A list of nodes.
927
+ """
928
+ nodes = []
929
+ node_managers = [self]
930
+ while node_managers:
931
+ node_manager = node_managers.pop()
932
+ nodes.extend(node_manager.nodes())
933
+ for node in node_manager.nodes():
934
+ if isinstance(node, NodeManager):
935
+ node_managers.append(node)
936
+ for tree_node in self.get_tree_nodes():
937
+ stree = tree_node.symbol_tree
938
+ nodes.extend(stree.all_nodes())
939
+ return nodes
940
+
941
+ def get_node_from_name(self, node_name: str):
942
+ """
943
+ Get node from all NodeManagers in current symbol tree by `node_name`.
944
+
945
+ Args:
946
+ node_name (str): A str represents name of node as key of query.
947
+
948
+ Returns:
949
+ An instance of Node if found else None.
950
+ """
951
+ node_managers = [self]
952
+ while node_managers:
953
+ node_manager = node_managers.pop()
954
+ node = node_manager.get_node(node_name)
955
+ if node:
956
+ return node
957
+ for node in node_manager.nodes():
958
+ if isinstance(node, NodeManager):
959
+ node_managers.append(node)
960
+ return None
961
+
962
+ def print_node_tabulate(self, all_nodes: bool = False):
963
+ """
964
+ Print nodes information and nodes' topological relations.
965
+
966
+ Args:
967
+ all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
968
+ nodes, CellContainer nodes and sub symbol trees.
969
+ """
939
970
  try:
940
- from tabulate import tabulate
971
+ from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
941
972
  except ImportError:
942
- print("`print_tabular` relies on the library `tabulate`, "
943
- "which could not be found on this machine. Run `pip "
944
- "install tabulate` to install the library.")
945
- node_specs = [[n.get_node_type(), n.get_name(), n.get_targets(), n.get_args(), n.get_kwargs()]
946
- for n in self.nodes()]
947
- print(tabulate(node_specs, headers=['node type', 'name', 'target', 'args', 'kwargs']))
973
+ logger.warning("print_node_tabulate relies on the library `tabulate`, "
974
+ "which could not be found on this machine. Run `pip "
975
+ "install tabulate` to install the library.")
976
+ return ""
977
+ print(NodeManager.dump(self, self.get_manager_name()))
978
+ if all_nodes:
979
+ node_managers = [self]
980
+ while node_managers:
981
+ node_manager = node_managers.pop()
982
+ for node in node_manager.nodes():
983
+ if isinstance(node, NodeManager):
984
+ print(node.dump(node.get_manager_name()))
985
+ node_managers.append(node)
986
+ for tree_node in self.get_tree_nodes():
987
+ stree = tree_node.symbol_tree
988
+ stree.print_node_tabulate(all_nodes)
948
989
 
949
990
  def dump(self):
950
991
  """Dump graph."""
951
992
  dump_st = SymbolTreeDumper(self)
952
993
  dump_st.dump()
953
994
 
954
- def update_module_ast(self):
955
- for node in self._external_func_ast:
956
- self._module_ast.body.append(node)
957
- for node in self._father_class_ast:
958
- index = self._module_ast.body.index(self._class_ast)
959
- self._module_ast.body.insert(index, node)
995
+ def check_body_exist(self, body, code_bodies):
996
+ """Check whether body already exist in code_bodies"""
997
+ # Check import ast node exist by saving import code string to self._tmp_import_strs
998
+ if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
999
+ import_str = astunparse.unparse(body)
1000
+ if import_str in self._tmp_import_strs:
1001
+ return True
1002
+ self._tmp_import_strs.append(import_str)
1003
+ return False
1004
+
1005
+ # Check ClassDef ast node exist by using AstClassFinder
1006
+ if isinstance(body, ast.ClassDef):
1007
+ if sys.version_info >= (3, 9):
1008
+ class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
1009
+ else:
1010
+ class_finder = AstClassFinder(ast.Module(body=code_bodies))
1011
+ results = class_finder.find_all(body.name)
1012
+ return bool(results)
1013
+
1014
+ # Check FunctionDef ast node exist by using AstFunctionFinder
1015
+ if isinstance(body, ast.FunctionDef):
1016
+ if sys.version_info >= (3, 9):
1017
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
1018
+ else:
1019
+ function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
1020
+ results = function_finder.find_all(body.name)
1021
+ return bool(results)
1022
+
1023
+ return False
1024
+
1025
+ def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
1026
+ """
1027
+ For the unmodified symbol tree, only one definition code remains in the generated code.
1028
+ Everywhere else calling this symbol tree will use the class in this definition code.
1029
+ """
1030
+ # all modified ast.ClassDef will be exported to code
1031
+ if stree.is_modified():
1032
+ return False
1033
+ # all un-modified ast.ClassDef only keep one instance
1034
+ first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
1035
+ if first_cls_name is None:
1036
+ class_ast = stree.get_class_ast()
1037
+ if class_ast:
1038
+ self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
1039
+ return False
1040
+ # Un-modified ast.ClassDef already exist in code_bodies,
1041
+ # replace class name to class name of first un-modified ast.ClassDef.
1042
+ if sys.version_info >= (3, 9):
1043
+ replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
1044
+ else:
1045
+ replacer = AstReplacer(ast.Module(body=code_bodies))
1046
+ replacer.replace_all(stree.get_class_ast().name, first_cls_name)
1047
+ self._tmp_replacers.append(replacer)
1048
+ return True
1049
+
1050
+ def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
1051
+ """
1052
+ Convert nodes in stree to code_bodies
1053
+
1054
+ 1. Add import asts into code_bodies
1055
+ 2. Add class, function and other type of asts into code_bodies
1056
+ 3. Add father class asts into code_bodies
1057
+ 4. Add external function asts into code_bodies
1058
+ 5. Add subtrees to code_bodies
1059
+ 5.1 Add subtrees in construct to code_bodies
1060
+ 5.2 Add subtrees in CellContainers to code_bodies
1061
+
1062
+ """
1063
+ # Add import asts into code_bodies
1064
+ for body in stree.get_import_asts():
1065
+ if not self.check_body_exist(body, code_bodies):
1066
+ code_bodies.insert(insert_pos, body)
1067
+ insert_pos += 1
1068
+
1069
+ # Add class, function and other type of asts into code_bodies
1070
+ if stree.get_module_ast():
1071
+ for body in stree.get_module_ast().body:
1072
+ if self.check_body_exist(body, code_bodies):
1073
+ continue
1074
+ if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
1075
+ code_bodies.insert(insert_pos, body)
1076
+ else:
1077
+ code_bodies.append(body)
1078
+
1079
+ # Add father class asts into code_bodies
1080
+ for body in reversed(stree.get_father_class_ast()):
1081
+ if self.check_body_exist(body, code_bodies):
1082
+ # remove exist ast in old position, then insert ast to upper position
1083
+ if sys.version_info >= (3, 9):
1084
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
1085
+ else:
1086
+ exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
1087
+ code_bodies.remove(exist_ast)
1088
+ code_bodies.insert(insert_pos, body)
1089
+
1090
+ # Add external asts into code_bodies
1091
+ for body in stree.get_external_ast():
1092
+ if not self.check_body_exist(body, code_bodies):
1093
+ code_bodies.insert(insert_pos, body)
1094
+ insert_pos += 1
1095
+
1096
+ # Add subtrees to code_bodies
1097
+ for node in stree.get_tree_nodes():
1098
+ sub_stree = node.symbol_tree
1099
+ # Ignore TreeNode create by function in the class
1100
+ if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
1101
+ continue
1102
+ # For the unmodified class, update class name to name of first class
1103
+ if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
1104
+ continue
1105
+ self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
960
1106
 
961
1107
  def get_code(self) -> str:
962
1108
  """
@@ -965,34 +1111,22 @@ class SymbolTree(Observer, Observable):
965
1111
  Returns:
966
1112
  A str represents source code of modified network.
967
1113
  """
968
- self._remove_unused_import()
969
- if self._init_func_ast:
970
- self._remove_unused_field()
971
- self._remove_duplicated_import()
972
- self.update_module_ast()
1114
+ self._tmp_import_strs.clear()
1115
+ self._tmp_unmodified_strees.clear()
1116
+ self._tmp_replacers.clear()
1117
+ code_bodies = []
1118
+ self.convert_stree_to_code_bodies(self, code_bodies)
1119
+ if sys.version_info >= (3, 9):
1120
+ gencode_module = ast.Module(body=code_bodies, type_ignores=[])
1121
+ else:
1122
+ gencode_module = ast.Module(body=code_bodies)
1123
+ SymbolTree._remove_unused_import(gencode_module)
1124
+ SymbolTree._remove_duplicated_import(gencode_module)
973
1125
  ast.fix_missing_locations(self._module_ast)
974
- # Find all ast.ClassDef which can be export to code
975
- # Replace duplicated ast.ClassDef reference in main-ClassDef
976
- seen_class: {type, str} = {}
977
- allow_class_name = [self._class_ast.name]
978
- replacers = []
979
- SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
980
- # Add all non-ClassDef body to gencode_module
981
- # Add all ClassDef in allow_class_name to gencode_module
982
- # Use gencode_module to generate code
983
- bodies = []
984
- for body in self._module_ast.body:
985
- if not isinstance(body, ast.ClassDef):
986
- bodies.append(body)
987
- continue
988
- if body.name in allow_class_name:
989
- bodies.append(body)
990
- gencode_module = ast.Module(body=bodies)
991
- if_fixer = IfFixer()
992
- if_fixer.fix(gencode_module)
1126
+ IfFixer().fix(gencode_module)
993
1127
  code = astunparse.unparse(gencode_module)
994
- # Restore main-ClassDef
995
- for replacer in replacers:
1128
+ # Revert the class name to its original state
1129
+ for replacer in self._tmp_replacers:
996
1130
  replacer.undo_all()
997
1131
  return code
998
1132
 
@@ -1026,305 +1160,71 @@ class SymbolTree(Observer, Observable):
1026
1160
  f.write(source.encode('utf-8'))
1027
1161
  f.flush()
1028
1162
 
1029
- def update_scope_for_unique(self, node: Union[ast.Attribute, ast.Call, ast.Subscript]):
1030
- """ Update scope of ast node because of unique-ing of targets of other nodes. """
1031
- if isinstance(node, ast.Call):
1032
- self.update_scope_for_unique(node.func)
1033
- return
1034
- if not isinstance(node, (ast.Attribute, ast.Subscript)):
1035
- logger.warning(f"Cannot update node {astunparse.unparse(node)} for unique, type of node should "
1036
- f"be one of (ast.Attribute, ast.Subscript).")
1037
- return
1038
- scope = node.value
1039
- if not isinstance(scope, ast.Name):
1040
- self.update_scope_for_unique(scope)
1041
- return
1042
- scope_name = scope.id
1043
- scope_name_unique = self._target_namer.get_real_arg(scope_name)
1044
- scope.id = scope_name_unique
1045
-
1046
- def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
1163
+ def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
1164
+ node_manager: NodeManager):
1047
1165
  """ insert_to_ast_while_insert_node. """
1048
- node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
1049
- node_ast = node.get_ast()
1050
- if not isinstance(node_ast, ast.Assign):
1051
- raise RuntimeError("Only support insert cell op now")
1052
- if isinstance(node, TreeNode):
1053
- setattr(self._origin_network, node.get_name(), node.get_instance())
1054
- args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
1055
- [ScopedValue(ValueType.NamingValue, "", "obj"),
1056
- ScopedValue(ValueType.StringValue, "", node.get_name())])
1057
- value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
1058
- col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
1059
-
1060
- ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
1061
- assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
1062
- AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
1063
-
1064
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1065
- None if position is None else position.node.get_ast(),
1066
- position.before_node)
1067
- sub_stree: SymbolTree = node.symbol_tree
1068
- from .symbol_tree_builder import SymbolTreeBuilder
1069
- SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
1166
+ if new_node.get_node_type() == NodeType.Input:
1167
+ # insert a new input
1168
+ self._inputs.append(new_node)
1169
+ ast_construct = self.get_ast_root()
1170
+ arg: str = new_node.get_targets()[0].value
1171
+ ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
1172
+ AstModifier.append_arg_to_function(ast_construct, ast_arg)
1070
1173
  else:
1071
- AstModifier.insert_assign_to_function(self._init_func_ast,
1072
- targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
1073
- expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
1074
- args=[ScopedValue(ValueType.NamingValue, "", "obj"),
1075
- ScopedValue(ValueType.StringValue, "", node.get_name())])
1076
- AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
1077
- None if position is None else position.node.get_ast(),
1078
- position.before_node)
1079
- setattr(self._origin_network, node.get_name(), node.get_instance())
1080
-
1081
- def _remove_unused_import(self):
1082
- """remove unused import in self._module_ast"""
1083
- str_checker = StrChecker(self._module_ast)
1084
- for i in range(len(self._module_ast.body) - 1, -1, -1):
1085
- body = self._module_ast.body[i]
1086
- if not isinstance(body, (ast.Import, ast.ImportFrom)):
1087
- continue
1088
- if isinstance(body, ast.Import):
1089
- continue
1090
- if isinstance(body, ast.ImportFrom) and body.module == "cell":
1091
- self._module_ast.body.remove(body)
1092
- continue
1093
- for alias in body.names:
1094
- name = alias.asname if alias.asname else alias.name
1095
- if not str_checker.check(name):
1096
- if len(body.names) == 1:
1097
- self._module_ast.body.remove(body)
1098
- i += 1
1099
- else:
1100
- body.names.remove(alias)
1101
-
1102
- def _replace_container_node(self, old_node, new_nodes):
1103
- cellcontainer = getattr(old_node, "container")
1104
- index = cellcontainer.node_list.index(old_node)
1105
- for n in reversed(new_nodes):
1106
- cellcontainer.insert(index, n)
1107
- index = cellcontainer.node_list.index(old_node)
1108
- cellcontainer.erase(old_node)
1109
-
1110
- def _filter_out_to_delete_field(self, to_delete_field):
1111
- """filter out used field from `to_delete_field`"""
1112
- for func_def in self._class_ast.body:
1113
- if not isinstance(func_def, ast.FunctionDef):
1114
- continue
1115
- if func_def.name != "__init__":
1116
- to_delete_to_delete_keys = []
1117
- property_checker = CheckPropertyIsUsed(func_def)
1118
- for key, _ in self._deleted_field.items():
1119
- if property_checker.check("self", key):
1120
- to_delete_to_delete_keys.append(key)
1121
- property_checker = CheckPropertyIsUsed(func_def)
1122
- for key in to_delete_to_delete_keys:
1123
- self._deleted_field.pop(key)
1174
+ # insert a new assign statement
1175
+ ast_assign = new_node.get_ast()
1176
+ if ast_assign is None:
1177
+ func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
1178
+ new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
1179
+ ast_assign = new_node.update_ast_node()
1180
+ if not isinstance(ast_assign, ast.Assign):
1181
+ raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
1182
+ # Save instance into _origin_network.
1183
+ setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
1184
+ # Insert ast to __init__ function
1185
+ if isinstance(new_node, TreeNode):
1186
+ init_code = f"self.{new_node.get_name()} = " \
1187
+ f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
1124
1188
  else:
1125
- for body in func_def.body:
1126
- if not isinstance(body, ast.If):
1127
- continue
1128
- test = body.test
1129
- field_finder = FieldFinder(test)
1130
- to_delete_to_delete_keys = []
1131
- for key, _ in self._deleted_field.items():
1132
- if field_finder.check(key):
1133
- to_delete_to_delete_keys.append(key)
1134
- for key in to_delete_to_delete_keys:
1135
- self._deleted_field.pop(key)
1136
-
1137
- def _remove_unused_field(self):
1138
- """remove unused field in __init__ function"""
1139
- multi_targets = []
1140
- for index, body in enumerate(self._init_func_ast.body):
1141
- if not isinstance(body, ast.Assign):
1142
- continue
1143
- targets = body.targets
1144
- for target in targets:
1145
- if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
1146
- and target.value.id == "self":
1147
- self._deleted_field[target.attr] = index
1148
- if len(targets) > 1:
1149
- multi_targets.append(index)
1150
- self._filter_out_to_delete_field(self._deleted_field)
1151
- for i in range(len(self._init_func_ast.body) - 1, -1, -1):
1152
- if i in self._deleted_field.values():
1153
- if i in multi_targets:
1154
- raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
1155
- AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
1156
- ast.fix_missing_locations(self._init_func_ast)
1157
-
1158
- def _remove_duplicated_import(self):
1159
- """Remove duplicated import of 'net'."""
1160
- imports = []
1161
- for body in self._module_ast.body:
1162
- if isinstance(body, (ast.ImportFrom, ast.Import)):
1163
- import_str = astunparse.unparse(body)
1164
- if import_str not in imports:
1165
- imports.append(import_str)
1166
- else:
1167
- self._module_ast.body.remove(body)
1189
+ init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
1190
+ init_ast = ast.parse(init_code).body[0]
1191
+ AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
1192
+ # Insert ast to construct_function/class_internal_function
1193
+ ast_base_node = base_node.get_ast() if base_node else None
1194
+ ast_functiondef = node_manager.get_ast_functiondef()
1195
+ if not ast_functiondef:
1196
+ raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
1197
+ "when inserting the ast.")
1198
+ AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
1168
1199
 
1169
1200
  def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
1170
1201
  if isinstance(node_or_name, str):
1171
1202
  return self.get_node(node_or_name)
1172
1203
  return node_or_name
1173
1204
 
1174
- def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
1175
- """
1176
- Insert a node-tree into SymbolTree.
1177
- Note:
1178
- Inputs of intra sub-tree nodes need to be welly set.
1179
-
1180
- Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
1181
-
1182
- Args:
1183
- position (Position): A Position indicates an insert position point.
1184
- root (Node): An instance of node as root of node-tree to be inserted in.
1185
- insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
1186
- True.
1187
-
1188
- Returns:
1189
- An instance of node as root node of node-tree which has been inserted into SymbolTree.
1190
-
1191
- Raises:
1192
- RuntimeError: If 'position' is not in current SymbolTree.
1193
- """
1194
-
1195
- # if position not in current SymbolTree
1196
- if position.symbol_tree is not self:
1197
- raise RuntimeError("Position is not in current SymbolTree: ", position)
1198
-
1199
- queue: [Node] = [root]
1200
- todos: [] = []
1201
- inputs_list: [] = []
1202
- while queue:
1203
- cur_node = queue.pop(0)
1204
- if cur_node in todos:
1205
- continue
1206
- todos.append(cur_node)
1207
- node_inputs = cur_node.get_inputs()
1208
- inputs_list.append(node_inputs)
1209
- for node_input in node_inputs:
1210
- if node_input is not None:
1211
- queue.append(node_input)
1212
- todos.reverse()
1213
- inputs_list.reverse()
1214
- for index, todo in enumerate(todos):
1215
- self.insert_node(position, todo, insert_to_ast)
1216
- position = self.after(todo)
1217
- # relink input of node
1218
- original_inputs = inputs_list[index]
1219
- for arg_idx, original_input in enumerate(original_inputs):
1220
- if original_input is not None:
1221
- self.set_node_arg_by_node(todo, arg_idx, original_input)
1222
- return root
1223
-
1224
- def _unique_targets(self, node: Node):
1225
- """
1226
- Unique targets of node by _target_namer.
1227
-
1228
- Args:
1229
- node (Node): A Node whose targets to be uniqued.
1230
- """
1231
- new_targets: [ScopedValue] = []
1232
- if node.get_targets() is None:
1233
- return
1234
- for target in node.get_targets():
1235
- if not isinstance(target, ScopedValue):
1236
- raise TypeError("target should be ScopedValue, got: ", type(target))
1237
- unique_target = self._target_namer.get_name(target.value)
1238
- new_targets.append(ScopedValue.create_naming_value(unique_target, target.scope))
1239
- node.set_targets(new_targets)
1240
-
1241
- def _update_args_kwargs_for_unique(self, node: Node):
1242
- """
1243
- Update arguments and keyword arguments of node because unique-ing of targets of other nodes.
1244
-
1245
- Args:
1246
- node (Node): A Node whose arguments and keyword arguments to be updated.
1247
- """
1248
- result: {str: ScopedValue} = {}
1249
- if node.get_normalized_args() is None:
1250
- return
1251
- for key, arg in node.get_normalized_args().items():
1252
- if not isinstance(arg, ScopedValue):
1253
- raise TypeError("arg should be ScopedValue, got: ", type(arg))
1254
- if arg.type == ValueType.NamingValue:
1255
- # unique name
1256
- new_arg = ScopedValue(arg.type, arg.scope, self._target_namer.get_real_arg(arg.value))
1257
- result[key] = new_arg
1258
- else:
1259
- result[key] = arg
1260
- node.set_normalized_args(result)
1261
-
1262
- def _add_node2nodes(self, node: Node):
1263
- """
1264
- Add `node` to `_nodes` dict.
1265
-
1266
- Args:
1267
- node (Node): A Node to be added into `_nodes`.
1268
-
1269
- Raises:
1270
- RuntimeError: If name of the node is duplicated.
1271
- """
1272
- node_name = node.get_name()
1273
- if self._nodes.get(node_name) is not None:
1274
- raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
1275
- node)
1276
- self._nodes[node_name] = node
1277
-
1278
- def _insert_node(self, position: Optional[Position], node: Node):
1279
- """
1280
- Insert a node into SymbolTree.
1281
- 1. Add `node` to `_nodes`.
1282
- 2. Insert `node` to node list(source code order).
1283
- 3. Update topological relation and update inputs of `node`.
1284
-
1285
- Args:
1286
- position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
1287
- of SymbolTree.
1288
- node (Node): A Node to be inserted into SymbolTree.
1289
-
1290
- Raises:
1291
- RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
1292
- unless inserting first node.
1293
- """
1294
- if position is None:
1295
- if self._nodes:
1296
- raise RuntimeError("self._nodes should be empty")
1297
- self._head = node
1298
- else:
1299
- if position.before_node:
1300
- position.node.insert_before(node)
1301
- else:
1302
- position.node.insert_after(node)
1303
- self._tail = node
1304
- self._add_node2nodes(node)
1305
- self._topo_mgr.on_insert_node(node)
1306
- node.set_belong_symbol_tree(self)
1307
-
1308
1205
  def _handle_custom_obj_in_normalized_args(self, node: Node):
1309
1206
  """
1310
- Convert CustomObjValue type argument to NamingValue type argument by storing custom object in global_vars dict.
1207
+ Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
1311
1208
 
1312
1209
  Args:
1313
1210
  node (Node): A Node whose arguments and keyword arguments to be handled.
1314
1211
  """
1315
- result: {str, ScopedValue} = {}
1316
- for arg, value in node.get_normalized_args().items():
1212
+ normalized_args: {str, ScopedValue} = {}
1213
+ for key, value in node.get_normalized_args().items():
1317
1214
  if not isinstance(value, ScopedValue):
1318
1215
  raise TypeError("value should be ScopedValue, got: ", type(value))
1319
1216
  if value.type == ValueType.CustomObjValue:
1320
- field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
1321
- setattr(self._origin_network, field, value.value)
1322
- init_targets = [ScopedValue.create_naming_value(field, "self")]
1323
- AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
1324
- result[arg] = init_targets[0]
1217
+ # Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
1218
+ arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
1219
+ setattr(self._origin_network, arg_name, value.value)
1220
+ # Add new code to __init__(): self.arg_name = obj.arg_name
1221
+ new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
1222
+ self._init_func_ast.body.append(new_ast)
1223
+ # Modify node's normalized_args: CustomObjValue -> self.arg_name
1224
+ normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
1325
1225
  else:
1326
- result[arg] = value
1327
- node.set_normalized_args(result)
1226
+ normalized_args[key] = value
1227
+ node.set_normalized_args(normalized_args)
1328
1228
 
1329
1229
  def _get_cls_through_file(self):
1330
1230
  """
@@ -1336,12 +1236,14 @@ class SymbolTree(Observer, Observable):
1336
1236
  Returns:
1337
1237
  A class handle.
1338
1238
  """
1339
- self._update_container()
1340
1239
  file_path = os.getcwd()
1341
1240
  file_path = os.path.join(file_path, "rewritten_network")
1342
1241
  if not os.path.exists(file_path):
1343
- os.mkdir(file_path)
1344
- file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
1242
+ try:
1243
+ os.mkdir(file_path, mode=0o700)
1244
+ except FileExistsError:
1245
+ pass
1246
+ file_name = f"{self._opt_cls_name}_{id(self)}.py"
1345
1247
  network_file = os.path.join(file_path, file_name)
1346
1248
  with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
1347
1249
  source = self.get_code()
@@ -1355,15 +1257,20 @@ class SymbolTree(Observer, Observable):
1355
1257
 
1356
1258
  i = 0
1357
1259
  while not tmp_module:
1358
- try:
1359
- tmp_module = importlib.import_module(tmp_module_name)
1360
- except ModuleNotFoundError:
1260
+ spec = importlib.util.spec_from_file_location(tmp_module_name, network_file)
1261
+ if spec:
1262
+ tmp_module = importlib.util.module_from_spec(spec)
1263
+ spec.loader.exec_module(tmp_module)
1264
+ else:
1265
+ logger.warning(f"load module {tmp_module_name} failed, retrying.")
1361
1266
  if i > 10:
1362
1267
  break
1363
- time.sleep(0.1)
1268
+ time.sleep(0.5)
1364
1269
  i += 1
1365
1270
  if not tmp_module:
1366
1271
  logger.error(f"load module {tmp_module_name} failed.")
1272
+ # Save new module to sys.modules to support inspect.getsource().
1273
+ sys.modules[tmp_module_name] = tmp_module
1367
1274
  network_cls = getattr(tmp_module, self._opt_cls_name)
1368
1275
  if network_cls is None:
1369
1276
  raise RuntimeError("Can not find network class:", self._opt_cls_name)
@@ -1373,21 +1280,6 @@ class SymbolTree(Observer, Observable):
1373
1280
  self._modified = True
1374
1281
  self.changed(event)
1375
1282
 
1376
- def _update_container(self):
1377
- """Update instance of node in container."""
1378
- for node in self.nodes():
1379
- index = 0
1380
- if node.get_node_type() == NodeType.CellContainer:
1381
- for n in node.node_list:
1382
- if not n.valid:
1383
- continue
1384
- if n.get_node_type() == NodeType.Tree:
1385
- obj = n.symbol_tree.get_network()
1386
- node.get_instance()[index] = obj
1387
- else:
1388
- node.get_instance()[index] = n.get_instance()
1389
- index += 1
1390
-
1391
1283
  def _cal_difference_set(self, input, other):
1392
1284
  """Calculate different set of two sets."""
1393
1285
  set1 = set(input)
@@ -1409,50 +1301,3 @@ class SymbolTree(Observer, Observable):
1409
1301
  primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
1410
1302
  for p in primitives:
1411
1303
  new_net._primitives[p] = self._origin_network._primitives[p]
1412
-
1413
- def _update_names_for_unique(self, node: ast.AST):
1414
- """ Update names of ast nodes for unique. """
1415
- if isinstance(node, (ast.For, ast.If, ast.While)):
1416
- self._update_names_for_unique_branchs(node)
1417
- elif isinstance(node, ast.Assign):
1418
- self._update_names_for_unique(node.value)
1419
- for target in node.targets:
1420
- self._update_names_for_unique(target)
1421
- elif isinstance(node, ast.Call):
1422
- if isinstance(node.func, ast.Attribute):
1423
- self._update_names_for_unique(node.func.value)
1424
- for arg in node.args:
1425
- self._update_names_for_unique(arg)
1426
- for keyword in node.keywords:
1427
- self._update_names_for_unique(keyword)
1428
- elif isinstance(node, ast.UnaryOp):
1429
- self._update_names_for_unique(node.operand)
1430
- elif isinstance(node, ast.BinOp):
1431
- self._update_names_for_unique(node.left)
1432
- self._update_names_for_unique(node.right)
1433
- elif isinstance(node, (ast.Attribute, ast.Subscript, ast.Return)):
1434
- self._update_names_for_unique(node.value)
1435
- elif isinstance(node, (ast.List, ast.Tuple)):
1436
- for elt in node.elts:
1437
- self._update_names_for_unique(elt)
1438
- elif isinstance(node, ast.Compare):
1439
- for comparator in node.comparators:
1440
- self._update_names_for_unique(comparator)
1441
- elif isinstance(node, ast.Name):
1442
- node.id = self._target_namer.get_real_arg(node.id)
1443
-
1444
- def _update_names_for_unique_branchs(self, node: Union[ast.For, ast.If, ast.While]):
1445
- """ Update names of ast nodes for unique with ast.For, ast.If or ast.While """
1446
- if isinstance(node, ast.For):
1447
- self._update_names_for_unique(node.target)
1448
- self._update_names_for_unique(node.iter)
1449
- for body in node.body:
1450
- self._update_names_for_unique(body)
1451
- for body in node.orelse:
1452
- self._update_names_for_unique(body)
1453
- elif isinstance(node, (ast.If, ast.While)):
1454
- self._update_names_for_unique(node.test)
1455
- for body in node.body:
1456
- self._update_names_for_unique(body)
1457
- for body in node.orelse:
1458
- self._update_names_for_unique(body)