mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/rewrite/symbol_tree.py
CHANGED
|
@@ -14,29 +14,31 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
16
|
import stat
|
|
17
|
-
from typing import Optional, Union, Tuple, Any
|
|
17
|
+
from typing import Optional, Union, Tuple, Any, Dict, List
|
|
18
18
|
import os
|
|
19
19
|
import sys
|
|
20
20
|
import ast
|
|
21
21
|
import importlib.util
|
|
22
|
-
import types
|
|
23
22
|
import time
|
|
24
|
-
import astunparse
|
|
25
23
|
|
|
26
24
|
from mindspore.nn import Cell
|
|
27
25
|
from mindspore import log as logger
|
|
28
|
-
from
|
|
29
|
-
from .node import Node, TreeNode
|
|
26
|
+
from .node.node import Node, TreeNode
|
|
30
27
|
from .api.node_type import NodeType
|
|
31
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder,
|
|
28
|
+
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
|
|
32
29
|
from .api.scoped_value import ScopedValue, ValueType
|
|
33
30
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
34
|
-
from .
|
|
31
|
+
from .node.node_topological_manager import TopoManager
|
|
35
32
|
from .namer import TargetNamer, NodeNamer, ClassNamer
|
|
36
33
|
from .common.observer import Observer
|
|
37
34
|
from .common.observable import Observable
|
|
38
35
|
from .common.event import Event
|
|
36
|
+
from .node.node_manager import NodeManager
|
|
39
37
|
|
|
38
|
+
if sys.version_info >= (3, 9):
|
|
39
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
40
|
+
else:
|
|
41
|
+
import astunparse
|
|
40
42
|
|
|
41
43
|
class Position:
|
|
42
44
|
"""
|
|
@@ -80,6 +82,7 @@ class FieldFinder(AstFinder):
|
|
|
80
82
|
Args:
|
|
81
83
|
scope (ast.AST): An instance of ast node as search scope.
|
|
82
84
|
"""
|
|
85
|
+
|
|
83
86
|
def __init__(self, scope: ast.AST):
|
|
84
87
|
super().__init__(scope)
|
|
85
88
|
self._result = False
|
|
@@ -133,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
|
|
|
133
136
|
self.generic_visit(node)
|
|
134
137
|
|
|
135
138
|
|
|
136
|
-
class SymbolTree(Observer, Observable):
|
|
139
|
+
class SymbolTree(Observer, Observable, NodeManager):
|
|
137
140
|
"""
|
|
138
141
|
A symbol-tree usually corresponding to forward method of a network.
|
|
139
142
|
|
|
@@ -146,147 +149,135 @@ class SymbolTree(Observer, Observable):
|
|
|
146
149
|
"""
|
|
147
150
|
|
|
148
151
|
def __init__(self, origin_network: Cell, module_ast: ast.Module):
|
|
149
|
-
|
|
152
|
+
Observer.__init__(self)
|
|
150
153
|
Observable.__init__(self)
|
|
151
|
-
|
|
154
|
+
self._node_namer = NodeNamer()
|
|
155
|
+
self._node_namer.add_name('obj')
|
|
156
|
+
NodeManager.__init__(self, self._node_namer)
|
|
157
|
+
NodeManager.reg_observer(self, observer=self)
|
|
152
158
|
# init unique-namers
|
|
153
159
|
self._target_namer = TargetNamer()
|
|
154
|
-
|
|
155
|
-
# name or node would use as name of field, so name of origin network handler field should be added into \
|
|
156
|
-
# _node_name_namer.
|
|
157
|
-
self._node_name_namer.add_name(origin_network_key)
|
|
158
|
-
self._topo_mgr = TopoManager(self)
|
|
159
|
-
self._topo_mgr.reg_observer(self)
|
|
160
|
-
|
|
161
|
-
self._nodes: {str, Node} = {}
|
|
162
|
-
# parameters of forward method
|
|
163
|
-
self._inputs: [Node] = []
|
|
160
|
+
# input arguments of function
|
|
164
161
|
self._ori_cls_name = type(origin_network).__name__
|
|
165
162
|
self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
|
|
163
|
+
NodeManager.set_manager_name(self, self._opt_cls_name)
|
|
166
164
|
self._origin_network = origin_network
|
|
167
165
|
self._module_ast: ast.Module = module_ast
|
|
166
|
+
self._import_asts: Optional[ast.Ast] = []
|
|
168
167
|
self._class_ast: Optional[ast.ClassDef] = None
|
|
169
168
|
self._root_ast: Optional[ast.FunctionDef] = None
|
|
170
169
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
171
170
|
self._deleted_field = {}
|
|
172
171
|
self._deleted_node = []
|
|
173
|
-
self.
|
|
172
|
+
self._external_ast = []
|
|
174
173
|
self._father_class_ast = []
|
|
175
|
-
|
|
176
|
-
# head node is always point to the first node(in source code order) of SymbolTree
|
|
177
|
-
self._head = None
|
|
178
|
-
# tail node is always point to the last node(in source code order) of SymbolTree
|
|
179
|
-
self._tail = None
|
|
180
|
-
self._return: Optional[Node] = None
|
|
181
|
-
|
|
182
174
|
self._modified = False
|
|
183
|
-
self._node_visitor = None
|
|
184
|
-
|
|
185
175
|
self._tmp_file_limits = 20
|
|
186
176
|
self._tmp_files = []
|
|
187
177
|
self._saved_file_name = "./network_define.py"
|
|
188
178
|
# used to insert "sys.path.append(xxx)"
|
|
189
179
|
self._net_file_paths = []
|
|
180
|
+
self._tmp_import_strs = []
|
|
181
|
+
self._tmp_unmodified_strees: {type, str} = {}
|
|
182
|
+
self._tmp_replacers = []
|
|
183
|
+
# Record imported modules and names of each files
|
|
184
|
+
# The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
|
|
185
|
+
# Format: {file_path: {module: [name, ...], ...}, ...}
|
|
186
|
+
self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
|
|
190
187
|
|
|
191
188
|
def __del__(self):
|
|
192
189
|
for tmp_file in self._tmp_files:
|
|
193
190
|
tmp_file.close()
|
|
194
191
|
|
|
195
192
|
@staticmethod
|
|
196
|
-
def
|
|
197
|
-
"""
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
for node in nodes:
|
|
203
|
-
for arg in node.get_args():
|
|
204
|
-
if consumers.get(arg):
|
|
205
|
-
consumers[arg].append(node)
|
|
206
|
-
else:
|
|
207
|
-
consumers[arg] = [node]
|
|
208
|
-
for _, arg in node.get_kwargs():
|
|
209
|
-
if consumers.get(arg):
|
|
210
|
-
consumers[arg].append(node)
|
|
211
|
-
else:
|
|
212
|
-
consumers[arg] = [node]
|
|
213
|
-
for target in node.get_targets():
|
|
214
|
-
if providers.get(target) is not None:
|
|
215
|
-
raise RuntimeError(f"Target({target}) of node duplicated")
|
|
216
|
-
providers[target] = node
|
|
217
|
-
return consumers, providers
|
|
218
|
-
|
|
219
|
-
@staticmethod
|
|
220
|
-
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
|
|
221
|
-
"""Find all non-duplicated class name of SymbolTree recursively."""
|
|
222
|
-
replacer = AstReplacer(stree.get_class_ast())
|
|
223
|
-
replacers.append(replacer)
|
|
224
|
-
for node in stree.nodes():
|
|
225
|
-
if not isinstance(node, TreeNode):
|
|
193
|
+
def _remove_unused_import(module_ast):
|
|
194
|
+
"""remove unused import in self._module_ast"""
|
|
195
|
+
str_checker = StrChecker(module_ast)
|
|
196
|
+
for i in range(len(module_ast.body) - 1, -1, -1):
|
|
197
|
+
body = module_ast.body[i]
|
|
198
|
+
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
226
199
|
continue
|
|
227
|
-
if
|
|
200
|
+
if isinstance(body, ast.Import):
|
|
228
201
|
continue
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
# all modified ast.ClassDef should export to code
|
|
232
|
-
if sub_stree._modified:
|
|
233
|
-
allow_class_name.append(sub_stree._class_ast.name)
|
|
202
|
+
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
203
|
+
module_ast.body.remove(body)
|
|
234
204
|
continue
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
205
|
+
for alias in body.names:
|
|
206
|
+
name = alias.asname if alias.asname else alias.name
|
|
207
|
+
if not str_checker.check(name):
|
|
208
|
+
if len(body.names) == 1:
|
|
209
|
+
module_ast.body.remove(body)
|
|
210
|
+
i += 1
|
|
211
|
+
else:
|
|
212
|
+
body.names.remove(alias)
|
|
213
|
+
|
|
214
|
+
@staticmethod
|
|
215
|
+
def _remove_duplicated_import(module_ast):
|
|
216
|
+
"""Remove duplicated import of 'net'."""
|
|
217
|
+
imports = set()
|
|
218
|
+
futures = set()
|
|
219
|
+
classes = set()
|
|
220
|
+
|
|
221
|
+
class TransImportNode(ast.NodeTransformer):
|
|
222
|
+
"""Find all import nodes from input ast node."""
|
|
223
|
+
|
|
224
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
225
|
+
class_str = astunparse.unparse(node)
|
|
226
|
+
if class_str not in classes:
|
|
227
|
+
classes.add(node.name)
|
|
228
|
+
return node
|
|
229
|
+
|
|
230
|
+
def visit_Try(self, node: ast.Try) -> Any:
|
|
231
|
+
if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
|
|
232
|
+
import_str = astunparse.unparse(node)
|
|
233
|
+
if import_str not in imports:
|
|
234
|
+
imports.add(import_str)
|
|
235
|
+
return node
|
|
236
|
+
|
|
237
|
+
def visit_Import(self, node: ast.Import) -> Any:
|
|
238
|
+
import_str = astunparse.unparse(node)
|
|
239
|
+
if import_str not in imports:
|
|
240
|
+
imports.add(import_str)
|
|
241
|
+
return node
|
|
242
|
+
|
|
243
|
+
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
244
|
+
"""
|
|
245
|
+
Once the father class 'A' is defined in the current module, all the next imported class 'A' should
|
|
246
|
+
be removed. e.g.
|
|
247
|
+
def class A():
|
|
248
|
+
...
|
|
249
|
+
from xxx import A, B
|
|
250
|
+
=>
|
|
251
|
+
def class A():
|
|
252
|
+
...
|
|
253
|
+
from xxx import B
|
|
254
|
+
"""
|
|
255
|
+
import_str = astunparse.unparse(node)
|
|
256
|
+
|
|
257
|
+
if import_str not in imports:
|
|
258
|
+
imports.add(import_str)
|
|
259
|
+
# remove "__future__" module
|
|
260
|
+
if node.module == '__future__':
|
|
261
|
+
futures.add(node.module)
|
|
262
|
+
return
|
|
263
|
+
# remove modules which have been defined in the code file
|
|
264
|
+
# it occurs when class A is a father class and other sub-classes import A
|
|
265
|
+
for alias in node.names[:]:
|
|
266
|
+
if alias.name in classes:
|
|
267
|
+
node.names.remove(alias)
|
|
268
|
+
# if the alias(es) in node.names are all removed, this import statement should be removed
|
|
269
|
+
if not node.names:
|
|
270
|
+
return
|
|
271
|
+
return node
|
|
272
|
+
return
|
|
273
|
+
|
|
274
|
+
get_node_handler = TransImportNode()
|
|
275
|
+
get_node_handler.generic_visit(module_ast)
|
|
242
276
|
|
|
243
277
|
def finish_build(self):
|
|
244
278
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
245
279
|
self.add_event(Event.TopologicalChangeEvent)
|
|
246
280
|
|
|
247
|
-
def create_assign_node(self, targets, func_name, args, kwargs):
|
|
248
|
-
"""
|
|
249
|
-
Create a ast.Assign type node.
|
|
250
|
-
|
|
251
|
-
Args:
|
|
252
|
-
targets (list): _description_
|
|
253
|
-
func_name (_type_): _description_
|
|
254
|
-
args (_type_): _description_
|
|
255
|
-
kwargs (_type_): _description_
|
|
256
|
-
|
|
257
|
-
Returns:
|
|
258
|
-
_type_: _description_
|
|
259
|
-
"""
|
|
260
|
-
# create targets
|
|
261
|
-
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
262
|
-
# create call
|
|
263
|
-
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
264
|
-
ast_args = ast_creator_registry.get("Args")(args)
|
|
265
|
-
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
266
|
-
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
267
|
-
# create assign
|
|
268
|
-
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
269
|
-
return ast_node
|
|
270
|
-
|
|
271
|
-
def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
|
|
272
|
-
'''
|
|
273
|
-
Instantiate an instance of node whose type is `CallFunction`.
|
|
274
|
-
|
|
275
|
-
Args:
|
|
276
|
-
node_name (str): Name of node.
|
|
277
|
-
func_name (str): Name of function.
|
|
278
|
-
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
279
|
-
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
280
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
281
|
-
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
282
|
-
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
283
|
-
class.
|
|
284
|
-
'''
|
|
285
|
-
logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
|
|
286
|
-
node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
|
|
287
|
-
node.set_belong_symbol_tree(self)
|
|
288
|
-
return node
|
|
289
|
-
|
|
290
281
|
def get_ori_cls_name(self) -> str:
|
|
291
282
|
"""
|
|
292
283
|
Get class name of original network.
|
|
@@ -342,6 +333,7 @@ class SymbolTree(Observer, Observable):
|
|
|
342
333
|
corresponding network class.
|
|
343
334
|
"""
|
|
344
335
|
self._root_ast = ast_node
|
|
336
|
+
NodeManager.set_ast_functiondef(self, ast_node)
|
|
345
337
|
|
|
346
338
|
def get_class_ast(self):
|
|
347
339
|
"""
|
|
@@ -380,18 +372,6 @@ class SymbolTree(Observer, Observable):
|
|
|
380
372
|
"""
|
|
381
373
|
self._init_func_ast = ast_node
|
|
382
374
|
|
|
383
|
-
def get_inputs(self):
|
|
384
|
-
return self._inputs
|
|
385
|
-
|
|
386
|
-
def get_head_node(self):
|
|
387
|
-
"""
|
|
388
|
-
Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
|
|
389
|
-
|
|
390
|
-
Returns:
|
|
391
|
-
An instance of node.
|
|
392
|
-
"""
|
|
393
|
-
return self._head
|
|
394
|
-
|
|
395
375
|
def get_origin_network(self):
|
|
396
376
|
"""
|
|
397
377
|
Getter of `_origin_network`.
|
|
@@ -405,46 +385,53 @@ class SymbolTree(Observer, Observable):
|
|
|
405
385
|
"""Get dict of nodes"""
|
|
406
386
|
return self._nodes
|
|
407
387
|
|
|
408
|
-
def
|
|
409
|
-
"""Get
|
|
410
|
-
return self.
|
|
388
|
+
def get_node_namer(self):
|
|
389
|
+
"""Get _node_namer"""
|
|
390
|
+
return self._node_namer
|
|
411
391
|
|
|
412
|
-
def
|
|
413
|
-
"""
|
|
414
|
-
|
|
415
|
-
self._net_file_paths.append(file_path)
|
|
416
|
-
|
|
417
|
-
def get_net_file_path(self):
|
|
418
|
-
"""Get _net_file_paths"""
|
|
419
|
-
return self._net_file_paths
|
|
392
|
+
def is_modified(self):
|
|
393
|
+
"""
|
|
394
|
+
Check whether symbol tree is modified.
|
|
420
395
|
|
|
421
|
-
|
|
396
|
+
Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
|
|
397
|
+
the symbol tree is created.
|
|
422
398
|
"""
|
|
423
|
-
|
|
399
|
+
return self._modified
|
|
424
400
|
|
|
425
|
-
|
|
426
|
-
A generator for iterating Nodes of `SymbolTree`.
|
|
401
|
+
def set_modified_true(self):
|
|
427
402
|
"""
|
|
428
|
-
|
|
429
|
-
nodes = []
|
|
430
|
-
node = self._head
|
|
431
|
-
while node is not None:
|
|
432
|
-
nodes.append(node)
|
|
433
|
-
node = node.get_next()
|
|
434
|
-
return iter(nodes)
|
|
403
|
+
Set self._modified true.
|
|
435
404
|
|
|
436
|
-
|
|
405
|
+
Self._modified is set true when 'if' exists in the original network.
|
|
406
|
+
In this situation, different original network instance tends to be different.
|
|
407
|
+
Hence, the class name should be updated.
|
|
437
408
|
"""
|
|
438
|
-
|
|
409
|
+
self._modified = True
|
|
439
410
|
|
|
440
|
-
|
|
441
|
-
|
|
411
|
+
def get_import_asts(self):
|
|
412
|
+
"""Get _import_asts"""
|
|
413
|
+
return self._import_asts
|
|
442
414
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
415
|
+
def get_external_ast(self):
|
|
416
|
+
"""Get _external_ast"""
|
|
417
|
+
return self._external_ast
|
|
418
|
+
|
|
419
|
+
def get_father_class_ast(self):
|
|
420
|
+
"""Get _father_class_ast"""
|
|
421
|
+
return self._father_class_ast
|
|
422
|
+
|
|
423
|
+
def get_imported_modules(self, file_path: str):
|
|
424
|
+
"""Get all modules and module_paths in file of `file_path` ."""
|
|
425
|
+
return self._imported_modules.get(file_path, {})
|
|
446
426
|
|
|
447
|
-
|
|
427
|
+
def save_imported_modules(self, file_path: str, module: str, names: List[str]):
|
|
428
|
+
"""Save module and names into _imported_modules."""
|
|
429
|
+
imported_modules = self.get_imported_modules(file_path)
|
|
430
|
+
if imported_modules.get(module):
|
|
431
|
+
imported_modules[module].extend(names)
|
|
432
|
+
else:
|
|
433
|
+
imported_modules[module] = names
|
|
434
|
+
self._imported_modules[file_path] = imported_modules
|
|
448
435
|
|
|
449
436
|
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
|
450
437
|
"""
|
|
@@ -535,9 +522,11 @@ class SymbolTree(Observer, Observable):
|
|
|
535
522
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
536
523
|
return Position.create(node.get_belong_symbol_tree(), node, False)
|
|
537
524
|
|
|
538
|
-
def insert_node(self,
|
|
525
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
|
|
526
|
+
insert_to_ast: bool = True):
|
|
539
527
|
"""
|
|
540
|
-
Insert a node
|
|
528
|
+
Insert a node before or after base_node.
|
|
529
|
+
|
|
541
530
|
Note:
|
|
542
531
|
Name of node will be unique while inserting node into SymbolTree.
|
|
543
532
|
|
|
@@ -556,57 +545,73 @@ class SymbolTree(Observer, Observable):
|
|
|
556
545
|
Topological relation is updated and inputs of corresponding node is updated.
|
|
557
546
|
|
|
558
547
|
Args:
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
548
|
+
new_node (Node): Node to be inserted.
|
|
549
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
550
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
551
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
552
|
+
NodeManager of symboltree's construct function.
|
|
553
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
563
554
|
|
|
564
555
|
Returns:
|
|
565
556
|
An instance of node which has been inserted into SymbolTree.
|
|
566
557
|
|
|
567
558
|
Raises:
|
|
568
559
|
ValueError: Node in the SymbolTree is inserted into SymbolTree again.
|
|
569
|
-
RuntimeError: If 'position' is not in current SymbolTree.
|
|
570
560
|
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
571
561
|
"""
|
|
572
|
-
if
|
|
573
|
-
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
if
|
|
562
|
+
if new_node.get_belong_symbol_tree():
|
|
563
|
+
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
|
|
564
|
+
|
|
565
|
+
# Check if base_node in current SymbolTree
|
|
566
|
+
if base_node is not None:
|
|
567
|
+
stree = base_node.get_belong_symbol_tree()
|
|
568
|
+
if stree is not None and stree is not self:
|
|
569
|
+
raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
|
|
570
|
+
f"current: {self.get_ori_cls_name()}.")
|
|
571
|
+
|
|
572
|
+
# Check if node is inserted between Input node
|
|
573
|
+
if base_node is not None and base_node.get_node_type() == NodeType.Input:
|
|
584
574
|
valid = True
|
|
585
|
-
if
|
|
575
|
+
if before_node:
|
|
586
576
|
valid = False
|
|
587
|
-
if
|
|
577
|
+
if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
|
|
588
578
|
valid = False
|
|
589
579
|
if not valid:
|
|
590
|
-
raise RuntimeError("Can not insert a node before or between parameters:",
|
|
591
|
-
|
|
592
|
-
node_name = self._node_name_namer.get_name(node)
|
|
593
|
-
node.set_name(node_name)
|
|
580
|
+
raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
|
|
581
|
+
|
|
594
582
|
# save target name, which is used to provide unique target
|
|
595
|
-
if
|
|
596
|
-
for target in
|
|
583
|
+
if new_node.get_targets():
|
|
584
|
+
for target in new_node.get_targets():
|
|
597
585
|
self._target_namer.add_name(str(target))
|
|
598
|
-
self._handle_custom_obj_in_normalized_args(node)
|
|
599
|
-
self._insert_node(position, node)
|
|
600
|
-
if isinstance(node, TreeNode):
|
|
601
|
-
node.symbol_tree.reg_observer(self)
|
|
602
|
-
if self._node_visitor:
|
|
603
|
-
self._node_visitor.append_node(node)
|
|
604
|
-
# update init-function-ast and construct-function-ast
|
|
605
|
-
if insert_to_ast:
|
|
606
|
-
self._insert_to_ast_while_insert_node(node, position)
|
|
607
|
-
return node
|
|
608
586
|
|
|
609
|
-
|
|
587
|
+
self._handle_custom_obj_in_normalized_args(new_node)
|
|
588
|
+
|
|
589
|
+
# Insert node into NodeManager
|
|
590
|
+
if node_manager is None:
|
|
591
|
+
if base_node is None:
|
|
592
|
+
raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
|
|
593
|
+
node_manager = base_node.get_node_manager()
|
|
594
|
+
|
|
595
|
+
# set node's _belong_symbol_tree
|
|
596
|
+
new_node.set_belong_symbol_tree(self)
|
|
597
|
+
|
|
598
|
+
if node_manager is self:
|
|
599
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
600
|
+
if insert_to_ast:
|
|
601
|
+
# update init-function-ast and construct-function-ast
|
|
602
|
+
self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
|
|
603
|
+
else:
|
|
604
|
+
node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
|
|
605
|
+
|
|
606
|
+
# register code changed event observer, which is used to update _modified flag.
|
|
607
|
+
if new_node.get_node_type() == NodeType.Tree:
|
|
608
|
+
new_node.symbol_tree.reg_observer(self)
|
|
609
|
+
elif isinstance(new_node, NodeManager):
|
|
610
|
+
new_node.reg_observer(self)
|
|
611
|
+
|
|
612
|
+
return new_node
|
|
613
|
+
|
|
614
|
+
def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
|
|
610
615
|
"""
|
|
611
616
|
Append a node to SymbolTree.
|
|
612
617
|
|
|
@@ -614,13 +619,17 @@ class SymbolTree(Observer, Observable):
|
|
|
614
619
|
node (Node): An instance of node to be appended.
|
|
615
620
|
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
616
621
|
True.
|
|
622
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
623
|
+
NodeManager of symboltree's construct function.
|
|
617
624
|
|
|
618
625
|
Returns:
|
|
619
626
|
An instance of node which has been appended to SymbolTree.
|
|
620
627
|
"""
|
|
621
|
-
|
|
628
|
+
if node_manager is None:
|
|
629
|
+
node_manager = self
|
|
630
|
+
return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
|
|
622
631
|
|
|
623
|
-
def append_origin_field(self, node: Node) -> Node:
|
|
632
|
+
def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
|
|
624
633
|
"""
|
|
625
634
|
Append an original field node to SymbolTree. An original field node represents a node created from existing
|
|
626
635
|
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
|
|
@@ -629,26 +638,16 @@ class SymbolTree(Observer, Observable):
|
|
|
629
638
|
|
|
630
639
|
Args:
|
|
631
640
|
node (Node): An instance of node to be appended.
|
|
641
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
642
|
+
NodeManager of symboltree's construct function.
|
|
632
643
|
|
|
633
644
|
Returns:
|
|
634
645
|
An instance of node which has been appended to SymbolTree.
|
|
635
646
|
"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
elif node.get_node_type() == NodeType.Tree:
|
|
641
|
-
# add father_class_ast into main tree, used when get_code
|
|
642
|
-
for father_ast in node.symbol_tree.get_father_class_ast():
|
|
643
|
-
if father_ast not in self._father_class_ast:
|
|
644
|
-
self._father_class_ast.append(father_ast)
|
|
645
|
-
# add subtree's net path into main tree
|
|
646
|
-
for file_path in node.symbol_tree.get_net_file_path():
|
|
647
|
-
if file_path not in self._net_file_paths:
|
|
648
|
-
self.append_net_file_path(file_path)
|
|
649
|
-
return self.append_node(node, False)
|
|
650
|
-
|
|
651
|
-
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None):
|
|
647
|
+
return self.append_node(node, node_manager, False)
|
|
648
|
+
|
|
649
|
+
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
|
|
650
|
+
node_manager: NodeManager = None):
|
|
652
651
|
"""
|
|
653
652
|
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
|
|
654
653
|
This method is called while building SymbolTree usually.
|
|
@@ -658,13 +657,18 @@ class SymbolTree(Observer, Observable):
|
|
|
658
657
|
param_name (str): A str represents name of parameter of forward method of network class.
|
|
659
658
|
default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
|
|
660
659
|
means parameter has no default value.
|
|
660
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
661
|
+
NodeManager of symboltree's construct function.
|
|
661
662
|
|
|
662
663
|
Returns:
|
|
663
664
|
An instance of input node which has been appended to SymbolTree.
|
|
664
665
|
"""
|
|
665
666
|
if param_name == "self":
|
|
666
667
|
return
|
|
667
|
-
|
|
668
|
+
# check param_name duplicated
|
|
669
|
+
if node_manager is None:
|
|
670
|
+
node_manager = self
|
|
671
|
+
for input_node in node_manager._inputs:
|
|
668
672
|
targets = input_node.get_targets()
|
|
669
673
|
if len(targets) != 1:
|
|
670
674
|
raise RuntimeError("targets should have 1 elements")
|
|
@@ -677,9 +681,10 @@ class SymbolTree(Observer, Observable):
|
|
|
677
681
|
if exist_param == param_name:
|
|
678
682
|
raise RuntimeError("input duplicated:", param_name)
|
|
679
683
|
input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
|
|
680
|
-
self.append_origin_field(input_node)
|
|
684
|
+
self.append_origin_field(input_node, node_manager)
|
|
681
685
|
|
|
682
|
-
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST
|
|
686
|
+
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
|
|
687
|
+
node_manager: NodeManager = None) -> Optional[Node]:
|
|
683
688
|
"""
|
|
684
689
|
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
|
|
685
690
|
a list or a dict.
|
|
@@ -688,6 +693,8 @@ class SymbolTree(Observer, Observable):
|
|
|
688
693
|
Args:
|
|
689
694
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
690
695
|
ast_node (ast.AST): A ast node represents ast node.
|
|
696
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
697
|
+
NodeManager of symboltree's construct function.
|
|
691
698
|
|
|
692
699
|
Returns:
|
|
693
700
|
An instance of python node if a new node has been appended to SymbolTree else None.
|
|
@@ -696,9 +703,9 @@ class SymbolTree(Observer, Observable):
|
|
|
696
703
|
return None
|
|
697
704
|
if isinstance(ast_node, (list, dict)) and not ast_node:
|
|
698
705
|
return None
|
|
699
|
-
return self.append_python_node(ast_scope, ast_node)
|
|
706
|
+
return self.append_python_node(ast_scope, ast_node, node_manager)
|
|
700
707
|
|
|
701
|
-
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
|
|
708
|
+
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
|
|
702
709
|
"""
|
|
703
710
|
Append a python node to SymbolTree.
|
|
704
711
|
This method is called while building SymbolTree usually.
|
|
@@ -706,39 +713,50 @@ class SymbolTree(Observer, Observable):
|
|
|
706
713
|
Args:
|
|
707
714
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
708
715
|
ast_node (ast.AST): A ast node represents ast node.
|
|
716
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
717
|
+
NodeManager of symboltree's construct function.
|
|
709
718
|
|
|
710
719
|
Returns:
|
|
711
720
|
An instance of python node which has been appended to SymbolTree.
|
|
712
721
|
"""
|
|
713
722
|
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
714
|
-
node_name =
|
|
723
|
+
node_name = type(ast_node).__name__
|
|
715
724
|
node = Node.create_python_node(ast_node, node_name)
|
|
716
|
-
|
|
725
|
+
if node_manager is None or node_manager is self:
|
|
726
|
+
NodeManager.append_python_node(self, node)
|
|
727
|
+
else:
|
|
728
|
+
node_manager.append_python_node(node)
|
|
717
729
|
return node
|
|
718
730
|
|
|
719
|
-
def set_output(self, return_value: str,
|
|
731
|
+
def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
|
|
732
|
+
node_manager: NodeManager = None) -> Node:
|
|
720
733
|
"""
|
|
721
734
|
Update return value of return of forward method of network class.
|
|
722
735
|
|
|
723
736
|
Args:
|
|
724
737
|
return_value (str): A str represents new return value.
|
|
725
|
-
|
|
738
|
+
arg_index (int): A int indicates which value in return to be updated.
|
|
739
|
+
return_idx (int): A int indicates which return node to be updated. Default: 0.
|
|
740
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
|
|
741
|
+
symboltree's construct function.
|
|
726
742
|
|
|
727
743
|
Returns:
|
|
728
744
|
An instance of node represents return node after updated.
|
|
729
745
|
"""
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
746
|
+
node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
|
|
747
|
+
if not node_returns:
|
|
748
|
+
raise RuntimeError("Current node_manager has no output")
|
|
749
|
+
if return_idx >= len(node_returns):
|
|
750
|
+
raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
|
|
751
|
+
node_return = node_returns[return_idx]
|
|
752
|
+
self.set_node_arg(node_return, arg_index, return_value)
|
|
753
|
+
return node_return
|
|
734
754
|
|
|
735
755
|
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
|
|
736
756
|
"""
|
|
737
757
|
Erase a node from SymbolTree.
|
|
738
|
-
Note:
|
|
739
|
-
If node is depended on by other node, RuntimeError will raise.
|
|
740
758
|
|
|
741
|
-
|
|
759
|
+
Topological relation will be updated.
|
|
742
760
|
|
|
743
761
|
Args:
|
|
744
762
|
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
@@ -754,19 +772,21 @@ class SymbolTree(Observer, Observable):
|
|
|
754
772
|
node = self._get_real_node(node_or_name)
|
|
755
773
|
if node is None:
|
|
756
774
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
775
|
+
# erase node in NodeManager
|
|
776
|
+
node_manager = node.get_node_manager()
|
|
777
|
+
|
|
778
|
+
logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
|
|
779
|
+
f"node_manager: {node_manager.get_manager_name()}, "
|
|
780
|
+
f"code: {astunparse.unparse(node.get_ast()).strip()}, "
|
|
781
|
+
f"node_name:{node.get_name()}")
|
|
782
|
+
|
|
783
|
+
if node_manager is self:
|
|
784
|
+
NodeManager.erase_node(self, node)
|
|
785
|
+
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
786
|
+
if not ret:
|
|
787
|
+
raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
788
|
+
else:
|
|
789
|
+
node_manager.erase_node(node)
|
|
770
790
|
self._deleted_node.append(node.get_name())
|
|
771
791
|
return node
|
|
772
792
|
|
|
@@ -785,26 +805,17 @@ class SymbolTree(Observer, Observable):
|
|
|
785
805
|
RuntimeError: If 'old_node' is isolated.
|
|
786
806
|
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
787
807
|
"""
|
|
788
|
-
|
|
789
|
-
if hasattr(old_node, "container"):
|
|
790
|
-
self._replace_container_node(old_node, new_nodes)
|
|
791
|
-
return new_nodes[0]
|
|
792
808
|
real_old_node = self._get_real_node(old_node)
|
|
793
809
|
if real_old_node is None:
|
|
794
810
|
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
795
|
-
#
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
raise RuntimeError("Try replacing a isolated node: ", old_node)
|
|
800
|
-
if prev_node is None:
|
|
801
|
-
position = self.before(next_node)
|
|
802
|
-
else:
|
|
803
|
-
position = self.after(prev_node)
|
|
811
|
+
# insert new_nodes into node_manager
|
|
812
|
+
node_manager = real_old_node.get_node_manager()
|
|
813
|
+
# insert new_nodes into NodeManager
|
|
814
|
+
base_node = old_node
|
|
804
815
|
for node in new_nodes:
|
|
805
|
-
self.insert_node(
|
|
806
|
-
|
|
807
|
-
self.erase_node(old_node)
|
|
816
|
+
self.insert_node(node, base_node, False, node_manager, True)
|
|
817
|
+
base_node = node
|
|
818
|
+
_ = self.erase_node(old_node)
|
|
808
819
|
return new_nodes[-1]
|
|
809
820
|
|
|
810
821
|
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
|
|
@@ -868,6 +879,15 @@ class SymbolTree(Observer, Observable):
|
|
|
868
879
|
"""Get a unique name in the symboltree"""
|
|
869
880
|
return self._target_namer.get_name(name)
|
|
870
881
|
|
|
882
|
+
def unique_func_name(self, name: str):
|
|
883
|
+
"""Get a unique function name in the symboltree"""
|
|
884
|
+
if not hasattr(self._origin_network, name):
|
|
885
|
+
return name
|
|
886
|
+
suffix = 1
|
|
887
|
+
while hasattr(self._origin_network, f"{name}_{suffix}"):
|
|
888
|
+
suffix += 1
|
|
889
|
+
return f"{name}_{suffix}"
|
|
890
|
+
|
|
871
891
|
def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
|
|
872
892
|
"""
|
|
873
893
|
Set target of `node` .
|
|
@@ -895,21 +915,191 @@ class SymbolTree(Observer, Observable):
|
|
|
895
915
|
node.set_targets(targets)
|
|
896
916
|
self._topo_mgr.on_update_target(node, index, old_target, target)
|
|
897
917
|
|
|
898
|
-
def
|
|
899
|
-
|
|
918
|
+
def all_nodes(self):
|
|
919
|
+
"""
|
|
920
|
+
Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
|
|
921
|
+
|
|
922
|
+
Returns:
|
|
923
|
+
A list of nodes.
|
|
924
|
+
"""
|
|
925
|
+
nodes = []
|
|
926
|
+
node_managers = [self]
|
|
927
|
+
while node_managers:
|
|
928
|
+
node_manager = node_managers.pop()
|
|
929
|
+
nodes.extend(node_manager.nodes())
|
|
930
|
+
for node in node_manager.nodes():
|
|
931
|
+
if isinstance(node, NodeManager):
|
|
932
|
+
node_managers.append(node)
|
|
933
|
+
for tree_node in self.get_tree_nodes():
|
|
934
|
+
stree = tree_node.symbol_tree
|
|
935
|
+
nodes.extend(stree.all_nodes())
|
|
936
|
+
return nodes
|
|
937
|
+
|
|
938
|
+
def get_node_from_name(self, node_name: str):
|
|
939
|
+
"""
|
|
940
|
+
Get node from all NodeManagers in current symbol tree by `node_name`.
|
|
941
|
+
|
|
942
|
+
Args:
|
|
943
|
+
node_name (str): A str represents name of node as key of query.
|
|
944
|
+
|
|
945
|
+
Returns:
|
|
946
|
+
An instance of Node if found else None.
|
|
947
|
+
"""
|
|
948
|
+
node_managers = [self]
|
|
949
|
+
while node_managers:
|
|
950
|
+
node_manager = node_managers.pop()
|
|
951
|
+
node = node_manager.get_node(node_name)
|
|
952
|
+
if node:
|
|
953
|
+
return node
|
|
954
|
+
for node in node_manager.nodes():
|
|
955
|
+
if isinstance(node, NodeManager):
|
|
956
|
+
node_managers.append(node)
|
|
957
|
+
return None
|
|
958
|
+
|
|
959
|
+
def print_node_tabulate(self, all_nodes: bool = False):
|
|
960
|
+
"""
|
|
961
|
+
Print nodes information and nodes' topological relations.
|
|
962
|
+
|
|
963
|
+
Args:
|
|
964
|
+
all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
|
|
965
|
+
nodes, CellContainer nodes and sub symbol trees.
|
|
966
|
+
"""
|
|
967
|
+
try:
|
|
968
|
+
from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
|
|
969
|
+
except ImportError:
|
|
970
|
+
logger.warning("print_node_tabulate relies on the library `tabulate`, "
|
|
971
|
+
"which could not be found on this machine. Run `pip "
|
|
972
|
+
"install tabulate` to install the library.")
|
|
973
|
+
return ""
|
|
974
|
+
print(NodeManager.dump(self, self.get_manager_name()))
|
|
975
|
+
if all_nodes:
|
|
976
|
+
node_managers = [self]
|
|
977
|
+
while node_managers:
|
|
978
|
+
node_manager = node_managers.pop()
|
|
979
|
+
for node in node_manager.nodes():
|
|
980
|
+
if isinstance(node, NodeManager):
|
|
981
|
+
print(node.dump(node.get_manager_name()))
|
|
982
|
+
node_managers.append(node)
|
|
983
|
+
for tree_node in self.get_tree_nodes():
|
|
984
|
+
stree = tree_node.symbol_tree
|
|
985
|
+
stree.print_node_tabulate(all_nodes)
|
|
900
986
|
|
|
901
987
|
def dump(self):
|
|
902
988
|
"""Dump graph."""
|
|
903
989
|
dump_st = SymbolTreeDumper(self)
|
|
904
990
|
dump_st.dump()
|
|
905
991
|
|
|
906
|
-
def
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
992
|
+
def check_body_exist(self, body, code_bodies):
|
|
993
|
+
"""Check whether body already exist in code_bodies"""
|
|
994
|
+
# Check import ast node exist by saving import code string to self._tmp_import_strs
|
|
995
|
+
if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
|
|
996
|
+
import_str = astunparse.unparse(body)
|
|
997
|
+
if import_str in self._tmp_import_strs:
|
|
998
|
+
return True
|
|
999
|
+
self._tmp_import_strs.append(import_str)
|
|
1000
|
+
return False
|
|
1001
|
+
|
|
1002
|
+
# Check ClassDef ast node exist by using AstClassFinder
|
|
1003
|
+
if isinstance(body, ast.ClassDef):
|
|
1004
|
+
if sys.version_info >= (3, 9):
|
|
1005
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1006
|
+
else:
|
|
1007
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies))
|
|
1008
|
+
results = class_finder.find_all(body.name)
|
|
1009
|
+
return bool(results)
|
|
1010
|
+
|
|
1011
|
+
# Check FunctionDef ast node exist by using AstFunctionFinder
|
|
1012
|
+
if isinstance(body, ast.FunctionDef):
|
|
1013
|
+
if sys.version_info >= (3, 9):
|
|
1014
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1015
|
+
else:
|
|
1016
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
|
|
1017
|
+
results = function_finder.find_all(body.name)
|
|
1018
|
+
return bool(results)
|
|
1019
|
+
|
|
1020
|
+
return False
|
|
1021
|
+
|
|
1022
|
+
def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
|
|
1023
|
+
"""
|
|
1024
|
+
For the unmodified symbol tree, only one definition code remains in the generated code.
|
|
1025
|
+
Everywhere else calling this symbol tree will use the class in this definition code.
|
|
1026
|
+
"""
|
|
1027
|
+
# all modified ast.ClassDef will be exported to code
|
|
1028
|
+
if stree.is_modified():
|
|
1029
|
+
return False
|
|
1030
|
+
# all un-modified ast.ClassDef only keep one instance
|
|
1031
|
+
first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
|
|
1032
|
+
if first_cls_name is None:
|
|
1033
|
+
class_ast = stree.get_class_ast()
|
|
1034
|
+
if class_ast:
|
|
1035
|
+
self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
|
|
1036
|
+
return False
|
|
1037
|
+
# Un-modified ast.ClassDef already exist in code_bodies,
|
|
1038
|
+
# replace class name to class name of first un-modified ast.ClassDef.
|
|
1039
|
+
if sys.version_info >= (3, 9):
|
|
1040
|
+
replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1041
|
+
else:
|
|
1042
|
+
replacer = AstReplacer(ast.Module(body=code_bodies))
|
|
1043
|
+
replacer.replace_all(stree.get_class_ast().name, first_cls_name)
|
|
1044
|
+
self._tmp_replacers.append(replacer)
|
|
1045
|
+
return True
|
|
1046
|
+
|
|
1047
|
+
def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
|
|
1048
|
+
"""
|
|
1049
|
+
Convert nodes in stree to code_bodies
|
|
1050
|
+
|
|
1051
|
+
1. Add import asts into code_bodies
|
|
1052
|
+
2. Add class, function and other type of asts into code_bodies
|
|
1053
|
+
3. Add father class asts into code_bodies
|
|
1054
|
+
4. Add external function asts into code_bodies
|
|
1055
|
+
5. Add subtrees to code_bodies
|
|
1056
|
+
5.1 Add subtrees in construct to code_bodies
|
|
1057
|
+
5.2 Add subtrees in CellContainers to code_bodies
|
|
1058
|
+
|
|
1059
|
+
"""
|
|
1060
|
+
# Add import asts into code_bodies
|
|
1061
|
+
for body in stree.get_import_asts():
|
|
1062
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1063
|
+
code_bodies.insert(insert_pos, body)
|
|
1064
|
+
insert_pos += 1
|
|
1065
|
+
|
|
1066
|
+
# Add class, function and other type of asts into code_bodies
|
|
1067
|
+
if stree.get_module_ast():
|
|
1068
|
+
for body in stree.get_module_ast().body:
|
|
1069
|
+
if self.check_body_exist(body, code_bodies):
|
|
1070
|
+
continue
|
|
1071
|
+
if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
|
|
1072
|
+
code_bodies.insert(insert_pos, body)
|
|
1073
|
+
else:
|
|
1074
|
+
code_bodies.append(body)
|
|
1075
|
+
|
|
1076
|
+
# Add father class asts into code_bodies
|
|
1077
|
+
for body in reversed(stree.get_father_class_ast()):
|
|
1078
|
+
if self.check_body_exist(body, code_bodies):
|
|
1079
|
+
# remove exist ast in old position, then insert ast to upper position
|
|
1080
|
+
if sys.version_info >= (3, 9):
|
|
1081
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
|
|
1082
|
+
else:
|
|
1083
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
|
|
1084
|
+
code_bodies.remove(exist_ast)
|
|
1085
|
+
code_bodies.insert(insert_pos, body)
|
|
1086
|
+
|
|
1087
|
+
# Add external asts into code_bodies
|
|
1088
|
+
for body in stree.get_external_ast():
|
|
1089
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1090
|
+
code_bodies.insert(insert_pos, body)
|
|
1091
|
+
insert_pos += 1
|
|
1092
|
+
|
|
1093
|
+
# Add subtrees to code_bodies
|
|
1094
|
+
for node in stree.get_tree_nodes():
|
|
1095
|
+
sub_stree = node.symbol_tree
|
|
1096
|
+
# Ignore TreeNode create by function in the class
|
|
1097
|
+
if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
|
|
1098
|
+
continue
|
|
1099
|
+
# For the unmodified class, update class name to name of first class
|
|
1100
|
+
if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
|
|
1101
|
+
continue
|
|
1102
|
+
self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
|
|
913
1103
|
|
|
914
1104
|
def get_code(self) -> str:
|
|
915
1105
|
"""
|
|
@@ -918,34 +1108,22 @@ class SymbolTree(Observer, Observable):
|
|
|
918
1108
|
Returns:
|
|
919
1109
|
A str represents source code of modified network.
|
|
920
1110
|
"""
|
|
921
|
-
self.
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
self.
|
|
1111
|
+
self._tmp_import_strs.clear()
|
|
1112
|
+
self._tmp_unmodified_strees.clear()
|
|
1113
|
+
self._tmp_replacers.clear()
|
|
1114
|
+
code_bodies = []
|
|
1115
|
+
self.convert_stree_to_code_bodies(self, code_bodies)
|
|
1116
|
+
if sys.version_info >= (3, 9):
|
|
1117
|
+
gencode_module = ast.Module(body=code_bodies, type_ignores=[])
|
|
1118
|
+
else:
|
|
1119
|
+
gencode_module = ast.Module(body=code_bodies)
|
|
1120
|
+
SymbolTree._remove_unused_import(gencode_module)
|
|
1121
|
+
SymbolTree._remove_duplicated_import(gencode_module)
|
|
926
1122
|
ast.fix_missing_locations(self._module_ast)
|
|
927
|
-
|
|
928
|
-
# Replace duplicated ast.ClassDef reference in main-ClassDef
|
|
929
|
-
seen_class: {type, str} = {}
|
|
930
|
-
allow_class_name = [self._class_ast.name]
|
|
931
|
-
replacers = []
|
|
932
|
-
SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
|
|
933
|
-
# Add all non-ClassDef body to gencode_module
|
|
934
|
-
# Add all ClassDef in allow_class_name to gencode_module
|
|
935
|
-
# Use gencode_module to generate code
|
|
936
|
-
bodies = []
|
|
937
|
-
for body in self._module_ast.body:
|
|
938
|
-
if not isinstance(body, ast.ClassDef):
|
|
939
|
-
bodies.append(body)
|
|
940
|
-
continue
|
|
941
|
-
if body.name in allow_class_name:
|
|
942
|
-
bodies.append(body)
|
|
943
|
-
gencode_module = ast.Module(body=bodies)
|
|
944
|
-
if_fixer = IfFixer()
|
|
945
|
-
if_fixer.fix(gencode_module)
|
|
1123
|
+
IfFixer().fix(gencode_module)
|
|
946
1124
|
code = astunparse.unparse(gencode_module)
|
|
947
|
-
#
|
|
948
|
-
for replacer in
|
|
1125
|
+
# Revert the class name to its original state
|
|
1126
|
+
for replacer in self._tmp_replacers:
|
|
949
1127
|
replacer.undo_all()
|
|
950
1128
|
return code
|
|
951
1129
|
|
|
@@ -979,251 +1157,71 @@ class SymbolTree(Observer, Observable):
|
|
|
979
1157
|
f.write(source.encode('utf-8'))
|
|
980
1158
|
f.flush()
|
|
981
1159
|
|
|
982
|
-
def
|
|
1160
|
+
def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
|
|
1161
|
+
node_manager: NodeManager):
|
|
983
1162
|
""" insert_to_ast_while_insert_node. """
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
992
|
-
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
993
|
-
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
994
|
-
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
995
|
-
|
|
996
|
-
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
997
|
-
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
998
|
-
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
999
|
-
|
|
1000
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1001
|
-
None if position is None else position.node.get_ast(),
|
|
1002
|
-
position.before_node)
|
|
1003
|
-
sub_stree: SymbolTree = node.symbol_tree
|
|
1004
|
-
from .symbol_tree_builder import SymbolTreeBuilder
|
|
1005
|
-
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
1163
|
+
if new_node.get_node_type() == NodeType.Input:
|
|
1164
|
+
# insert a new input
|
|
1165
|
+
self._inputs.append(new_node)
|
|
1166
|
+
ast_construct = self.get_ast_root()
|
|
1167
|
+
arg: str = new_node.get_targets()[0].value
|
|
1168
|
+
ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
|
|
1169
|
+
AstModifier.append_arg_to_function(ast_construct, ast_arg)
|
|
1006
1170
|
else:
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
body = self._module_ast.body[i]
|
|
1022
|
-
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
1023
|
-
continue
|
|
1024
|
-
if isinstance(body, ast.Import):
|
|
1025
|
-
continue
|
|
1026
|
-
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
1027
|
-
self._module_ast.body.remove(body)
|
|
1028
|
-
continue
|
|
1029
|
-
for alias in body.names:
|
|
1030
|
-
name = alias.asname if alias.asname else alias.name
|
|
1031
|
-
if not str_checker.check(name):
|
|
1032
|
-
if len(body.names) == 1:
|
|
1033
|
-
self._module_ast.body.remove(body)
|
|
1034
|
-
i += 1
|
|
1035
|
-
else:
|
|
1036
|
-
body.names.remove(alias)
|
|
1037
|
-
|
|
1038
|
-
def _replace_container_node(self, old_node, new_nodes):
|
|
1039
|
-
cellcontainer = getattr(old_node, "container")
|
|
1040
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1041
|
-
for n in reversed(new_nodes):
|
|
1042
|
-
cellcontainer.insert(index, n)
|
|
1043
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1044
|
-
cellcontainer.erase(old_node)
|
|
1045
|
-
|
|
1046
|
-
def _filter_out_to_delete_field(self, to_delete_field):
|
|
1047
|
-
"""filter out used field from `to_delete_field`"""
|
|
1048
|
-
for func_def in self._class_ast.body:
|
|
1049
|
-
if not isinstance(func_def, ast.FunctionDef):
|
|
1050
|
-
continue
|
|
1051
|
-
if func_def.name != "__init__":
|
|
1052
|
-
to_delete_to_delete_keys = []
|
|
1053
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1054
|
-
for key, _ in self._deleted_field.items():
|
|
1055
|
-
if property_checker.check("self", key):
|
|
1056
|
-
to_delete_to_delete_keys.append(key)
|
|
1057
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1058
|
-
for key in to_delete_to_delete_keys:
|
|
1059
|
-
self._deleted_field.pop(key)
|
|
1171
|
+
# insert a new assign statement
|
|
1172
|
+
ast_assign = new_node.get_ast()
|
|
1173
|
+
if ast_assign is None:
|
|
1174
|
+
func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
|
|
1175
|
+
new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
|
|
1176
|
+
ast_assign = new_node.update_ast_node()
|
|
1177
|
+
if not isinstance(ast_assign, ast.Assign):
|
|
1178
|
+
raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
|
|
1179
|
+
# Save instance into _origin_network.
|
|
1180
|
+
setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
|
|
1181
|
+
# Insert ast to __init__ function
|
|
1182
|
+
if isinstance(new_node, TreeNode):
|
|
1183
|
+
init_code = f"self.{new_node.get_name()} = " \
|
|
1184
|
+
f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
|
|
1060
1185
|
else:
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
self._deleted_field.pop(key)
|
|
1072
|
-
|
|
1073
|
-
def _remove_unused_field(self):
|
|
1074
|
-
"""remove unused field in __init__ function"""
|
|
1075
|
-
multi_targets = []
|
|
1076
|
-
self._deleted_field = {}
|
|
1077
|
-
for index, body in enumerate(self._init_func_ast.body):
|
|
1078
|
-
if not isinstance(body, ast.Assign):
|
|
1079
|
-
continue
|
|
1080
|
-
targets = body.targets
|
|
1081
|
-
for target in targets:
|
|
1082
|
-
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
|
|
1083
|
-
and target.value.id == "self":
|
|
1084
|
-
self._deleted_field[target.attr] = index
|
|
1085
|
-
if len(targets) > 1:
|
|
1086
|
-
multi_targets.append(index)
|
|
1087
|
-
self._filter_out_to_delete_field(self._deleted_field)
|
|
1088
|
-
for i in range(len(self._init_func_ast.body) - 1, -1, -1):
|
|
1089
|
-
if i in self._deleted_field.values():
|
|
1090
|
-
if i in multi_targets:
|
|
1091
|
-
raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
|
|
1092
|
-
AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
|
|
1093
|
-
ast.fix_missing_locations(self._init_func_ast)
|
|
1094
|
-
|
|
1095
|
-
def _remove_duplicated_import(self):
|
|
1096
|
-
"""Remove duplicated import of 'net'."""
|
|
1097
|
-
imports = []
|
|
1098
|
-
for body in self._module_ast.body:
|
|
1099
|
-
if isinstance(body, (ast.ImportFrom, ast.Import)):
|
|
1100
|
-
import_str = astunparse.unparse(body)
|
|
1101
|
-
if import_str not in imports:
|
|
1102
|
-
imports.append(import_str)
|
|
1103
|
-
else:
|
|
1104
|
-
self._module_ast.body.remove(body)
|
|
1186
|
+
init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
|
|
1187
|
+
init_ast = ast.parse(init_code).body[0]
|
|
1188
|
+
AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
|
|
1189
|
+
# Insert ast to construct_function/class_internal_function
|
|
1190
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
1191
|
+
ast_functiondef = node_manager.get_ast_functiondef()
|
|
1192
|
+
if not ast_functiondef:
|
|
1193
|
+
raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
|
|
1194
|
+
"when inserting the ast.")
|
|
1195
|
+
AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
|
|
1105
1196
|
|
|
1106
1197
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1107
1198
|
if isinstance(node_or_name, str):
|
|
1108
1199
|
return self.get_node(node_or_name)
|
|
1109
1200
|
return node_or_name
|
|
1110
1201
|
|
|
1111
|
-
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
|
1112
|
-
"""
|
|
1113
|
-
Insert a node-tree into SymbolTree.
|
|
1114
|
-
Note:
|
|
1115
|
-
Inputs of intra sub-tree nodes need to be welly set.
|
|
1116
|
-
|
|
1117
|
-
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
|
1118
|
-
|
|
1119
|
-
Args:
|
|
1120
|
-
position (Position): A Position indicates an insert position point.
|
|
1121
|
-
root (Node): An instance of node as root of node-tree to be inserted in.
|
|
1122
|
-
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
1123
|
-
True.
|
|
1124
|
-
|
|
1125
|
-
Returns:
|
|
1126
|
-
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
|
1127
|
-
|
|
1128
|
-
Raises:
|
|
1129
|
-
RuntimeError: If 'position' is not in current SymbolTree.
|
|
1130
|
-
"""
|
|
1131
|
-
|
|
1132
|
-
# if position not in current SymbolTree
|
|
1133
|
-
if position.symbol_tree is not self:
|
|
1134
|
-
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
|
1135
|
-
|
|
1136
|
-
queue: [Node] = [root]
|
|
1137
|
-
todos: [] = []
|
|
1138
|
-
inputs_list: [] = []
|
|
1139
|
-
while queue:
|
|
1140
|
-
cur_node = queue.pop(0)
|
|
1141
|
-
if cur_node in todos:
|
|
1142
|
-
continue
|
|
1143
|
-
todos.append(cur_node)
|
|
1144
|
-
node_inputs = cur_node.get_inputs()
|
|
1145
|
-
inputs_list.append(node_inputs)
|
|
1146
|
-
for node_input in node_inputs:
|
|
1147
|
-
if node_input is not None:
|
|
1148
|
-
queue.append(node_input)
|
|
1149
|
-
todos.reverse()
|
|
1150
|
-
inputs_list.reverse()
|
|
1151
|
-
for index, todo in enumerate(todos):
|
|
1152
|
-
self.insert_node(position, todo, insert_to_ast)
|
|
1153
|
-
position = self.after(todo)
|
|
1154
|
-
# relink input of node
|
|
1155
|
-
original_inputs = inputs_list[index]
|
|
1156
|
-
for arg_idx, original_input in enumerate(original_inputs):
|
|
1157
|
-
if original_input is not None:
|
|
1158
|
-
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
|
1159
|
-
return root
|
|
1160
|
-
|
|
1161
|
-
def _add_node2nodes(self, node: Node):
|
|
1162
|
-
"""
|
|
1163
|
-
Add `node` to `_nodes` dict.
|
|
1164
|
-
|
|
1165
|
-
Args:
|
|
1166
|
-
node (Node): A Node to be added into `_nodes`.
|
|
1167
|
-
|
|
1168
|
-
Raises:
|
|
1169
|
-
RuntimeError: If name of the node is duplicated.
|
|
1170
|
-
"""
|
|
1171
|
-
node_name = node.get_name()
|
|
1172
|
-
if self._nodes.get(node_name) is not None:
|
|
1173
|
-
raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
|
|
1174
|
-
node)
|
|
1175
|
-
self._nodes[node_name] = node
|
|
1176
|
-
|
|
1177
|
-
def _insert_node(self, position: Optional[Position], node: Node):
|
|
1178
|
-
"""
|
|
1179
|
-
Insert a node into SymbolTree.
|
|
1180
|
-
1. Add `node` to `_nodes`.
|
|
1181
|
-
2. Insert `node` to node list(source code order).
|
|
1182
|
-
3. Update topological relation and update inputs of `node`.
|
|
1183
|
-
|
|
1184
|
-
Args:
|
|
1185
|
-
position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
|
|
1186
|
-
of SymbolTree.
|
|
1187
|
-
node (Node): A Node to be inserted into SymbolTree.
|
|
1188
|
-
|
|
1189
|
-
Raises:
|
|
1190
|
-
RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
|
|
1191
|
-
unless inserting first node.
|
|
1192
|
-
"""
|
|
1193
|
-
if position is None:
|
|
1194
|
-
if self._nodes:
|
|
1195
|
-
raise RuntimeError("self._nodes should be empty")
|
|
1196
|
-
self._head = node
|
|
1197
|
-
else:
|
|
1198
|
-
if position.before_node:
|
|
1199
|
-
position.node.insert_before(node)
|
|
1200
|
-
else:
|
|
1201
|
-
position.node.insert_after(node)
|
|
1202
|
-
self._tail = node
|
|
1203
|
-
self._add_node2nodes(node)
|
|
1204
|
-
self._topo_mgr.on_insert_node(node)
|
|
1205
|
-
node.set_belong_symbol_tree(self)
|
|
1206
|
-
|
|
1207
1202
|
def _handle_custom_obj_in_normalized_args(self, node: Node):
|
|
1208
1203
|
"""
|
|
1209
|
-
Convert CustomObjValue type argument to NamingValue type argument by storing custom object
|
|
1204
|
+
Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
|
|
1210
1205
|
|
|
1211
1206
|
Args:
|
|
1212
1207
|
node (Node): A Node whose arguments and keyword arguments to be handled.
|
|
1213
1208
|
"""
|
|
1214
|
-
|
|
1215
|
-
for
|
|
1209
|
+
normalized_args: {str, ScopedValue} = {}
|
|
1210
|
+
for key, value in node.get_normalized_args().items():
|
|
1216
1211
|
if not isinstance(value, ScopedValue):
|
|
1217
1212
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
1218
1213
|
if value.type == ValueType.CustomObjValue:
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1214
|
+
# Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
|
|
1215
|
+
arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
|
|
1216
|
+
setattr(self._origin_network, arg_name, value.value)
|
|
1217
|
+
# Add new code to __init__(): self.arg_name = obj.arg_name
|
|
1218
|
+
new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
|
|
1219
|
+
self._init_func_ast.body.append(new_ast)
|
|
1220
|
+
# Modify node's normalized_args: CustomObjValue -> self.arg_name
|
|
1221
|
+
normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
|
|
1224
1222
|
else:
|
|
1225
|
-
|
|
1226
|
-
node.set_normalized_args(
|
|
1223
|
+
normalized_args[key] = value
|
|
1224
|
+
node.set_normalized_args(normalized_args)
|
|
1227
1225
|
|
|
1228
1226
|
def _get_cls_through_file(self):
|
|
1229
1227
|
"""
|
|
@@ -1235,12 +1233,14 @@ class SymbolTree(Observer, Observable):
|
|
|
1235
1233
|
Returns:
|
|
1236
1234
|
A class handle.
|
|
1237
1235
|
"""
|
|
1238
|
-
self._update_container()
|
|
1239
1236
|
file_path = os.getcwd()
|
|
1240
1237
|
file_path = os.path.join(file_path, "rewritten_network")
|
|
1241
1238
|
if not os.path.exists(file_path):
|
|
1242
|
-
|
|
1243
|
-
|
|
1239
|
+
try:
|
|
1240
|
+
os.mkdir(file_path, mode=0o700)
|
|
1241
|
+
except FileExistsError:
|
|
1242
|
+
pass
|
|
1243
|
+
file_name = f"{self._opt_cls_name}_{id(self)}.py"
|
|
1244
1244
|
network_file = os.path.join(file_path, file_name)
|
|
1245
1245
|
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
|
1246
1246
|
source = self.get_code()
|
|
@@ -1277,21 +1277,6 @@ class SymbolTree(Observer, Observable):
|
|
|
1277
1277
|
self._modified = True
|
|
1278
1278
|
self.changed(event)
|
|
1279
1279
|
|
|
1280
|
-
def _update_container(self):
|
|
1281
|
-
"""Update instance of node in container."""
|
|
1282
|
-
for node in self.nodes():
|
|
1283
|
-
index = 0
|
|
1284
|
-
if node.get_node_type() == NodeType.CellContainer:
|
|
1285
|
-
for n in node.node_list:
|
|
1286
|
-
if not n.valid:
|
|
1287
|
-
continue
|
|
1288
|
-
if n.get_node_type() == NodeType.Tree:
|
|
1289
|
-
obj = n.symbol_tree.get_network()
|
|
1290
|
-
node.get_instance()[index] = obj
|
|
1291
|
-
else:
|
|
1292
|
-
node.get_instance()[index] = n.get_instance()
|
|
1293
|
-
index += 1
|
|
1294
|
-
|
|
1295
1280
|
def _cal_difference_set(self, input, other):
|
|
1296
1281
|
"""Calculate different set of two sets."""
|
|
1297
1282
|
set1 = set(input)
|
|
@@ -1313,43 +1298,3 @@ class SymbolTree(Observer, Observable):
|
|
|
1313
1298
|
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1314
1299
|
for p in primitives:
|
|
1315
1300
|
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1316
|
-
|
|
1317
|
-
def _create_call_function(self, func, targets, args, kwargs):
|
|
1318
|
-
"""
|
|
1319
|
-
Create a Node object and generate the execution code to insert into the source code.
|
|
1320
|
-
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
1321
|
-
|
|
1322
|
-
Args:
|
|
1323
|
-
func (FunctionType) - The function to be called.
|
|
1324
|
-
targets (list [str]) - indicates the output name. As the output of the node in the source code.
|
|
1325
|
-
args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
|
|
1326
|
-
code. The default value is None, which means there is no parameter input in the cell.
|
|
1327
|
-
kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
|
|
1328
|
-
input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
|
|
1329
|
-
code as the 'kwargs' in the statement expression. The default value is None, which means there is no
|
|
1330
|
-
'kwargs' input.
|
|
1331
|
-
|
|
1332
|
-
Returns:
|
|
1333
|
-
An instance of `Node`.
|
|
1334
|
-
"""
|
|
1335
|
-
if not isinstance(func, types.FunctionType):
|
|
1336
|
-
raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
|
|
1337
|
-
|
|
1338
|
-
_package = func.__globals__['__package__']
|
|
1339
|
-
func_name = ".".join([_package, func.__name__]) if _package else func.__name__
|
|
1340
|
-
|
|
1341
|
-
ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
|
|
1342
|
-
scope_targets = [ScopedValue.create_naming_value(targets[0])]
|
|
1343
|
-
scope_func = ScopedValue.create_naming_value(func_name, "")
|
|
1344
|
-
call_args = list()
|
|
1345
|
-
for arg in args:
|
|
1346
|
-
if isinstance(arg, Node):
|
|
1347
|
-
call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
|
|
1348
|
-
else:
|
|
1349
|
-
call_args.append(ScopedValue.create_variable_value(arg))
|
|
1350
|
-
call_kwargs = {}
|
|
1351
|
-
for k, v in kwargs.items():
|
|
1352
|
-
call_kwargs[k] = ScopedValue.create_variable_value(v)
|
|
1353
|
-
node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
|
|
1354
|
-
call_kwargs)
|
|
1355
|
-
return node
|