mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/rewrite/symbol_tree.py
CHANGED
|
@@ -14,29 +14,31 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
16
|
import stat
|
|
17
|
-
from typing import Optional, Union, Tuple, Any
|
|
17
|
+
from typing import Optional, Union, Tuple, Any, Dict, List
|
|
18
18
|
import os
|
|
19
19
|
import sys
|
|
20
20
|
import ast
|
|
21
21
|
import importlib.util
|
|
22
|
-
import types
|
|
23
22
|
import time
|
|
24
|
-
import astunparse
|
|
25
23
|
|
|
26
24
|
from mindspore.nn import Cell
|
|
27
25
|
from mindspore import log as logger
|
|
28
|
-
from
|
|
29
|
-
from .node import Node, TreeNode
|
|
26
|
+
from .node.node import Node, TreeNode
|
|
30
27
|
from .api.node_type import NodeType
|
|
31
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder,
|
|
28
|
+
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
|
|
32
29
|
from .api.scoped_value import ScopedValue, ValueType
|
|
33
30
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
34
|
-
from .
|
|
31
|
+
from .node.node_topological_manager import TopoManager
|
|
35
32
|
from .namer import TargetNamer, NodeNamer, ClassNamer
|
|
36
33
|
from .common.observer import Observer
|
|
37
34
|
from .common.observable import Observable
|
|
38
35
|
from .common.event import Event
|
|
36
|
+
from .node.node_manager import NodeManager
|
|
39
37
|
|
|
38
|
+
if sys.version_info >= (3, 9):
|
|
39
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
40
|
+
else:
|
|
41
|
+
import astunparse
|
|
40
42
|
|
|
41
43
|
class Position:
|
|
42
44
|
"""
|
|
@@ -80,6 +82,7 @@ class FieldFinder(AstFinder):
|
|
|
80
82
|
Args:
|
|
81
83
|
scope (ast.AST): An instance of ast node as search scope.
|
|
82
84
|
"""
|
|
85
|
+
|
|
83
86
|
def __init__(self, scope: ast.AST):
|
|
84
87
|
super().__init__(scope)
|
|
85
88
|
self._result = False
|
|
@@ -133,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
|
|
|
133
136
|
self.generic_visit(node)
|
|
134
137
|
|
|
135
138
|
|
|
136
|
-
class SymbolTree(Observer, Observable):
|
|
139
|
+
class SymbolTree(Observer, Observable, NodeManager):
|
|
137
140
|
"""
|
|
138
141
|
A symbol-tree usually corresponding to forward method of a network.
|
|
139
142
|
|
|
@@ -146,147 +149,138 @@ class SymbolTree(Observer, Observable):
|
|
|
146
149
|
"""
|
|
147
150
|
|
|
148
151
|
def __init__(self, origin_network: Cell, module_ast: ast.Module):
|
|
149
|
-
|
|
152
|
+
Observer.__init__(self)
|
|
150
153
|
Observable.__init__(self)
|
|
151
|
-
|
|
154
|
+
self._node_namer = NodeNamer()
|
|
155
|
+
self._node_namer.add_name('obj')
|
|
156
|
+
NodeManager.__init__(self, self._node_namer)
|
|
157
|
+
NodeManager.reg_observer(self, observer=self)
|
|
152
158
|
# init unique-namers
|
|
153
159
|
self._target_namer = TargetNamer()
|
|
154
|
-
|
|
155
|
-
# name or node would use as name of field, so name of origin network handler field should be added into \
|
|
156
|
-
# _node_name_namer.
|
|
157
|
-
self._node_name_namer.add_name(origin_network_key)
|
|
158
|
-
self._topo_mgr = TopoManager(self)
|
|
159
|
-
self._topo_mgr.reg_observer(self)
|
|
160
|
-
|
|
161
|
-
self._nodes: {str, Node} = {}
|
|
162
|
-
# parameters of forward method
|
|
163
|
-
self._inputs: [Node] = []
|
|
160
|
+
# input arguments of function
|
|
164
161
|
self._ori_cls_name = type(origin_network).__name__
|
|
165
162
|
self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
|
|
163
|
+
NodeManager.set_manager_name(self, self._opt_cls_name)
|
|
166
164
|
self._origin_network = origin_network
|
|
167
165
|
self._module_ast: ast.Module = module_ast
|
|
166
|
+
self._import_asts: Optional[ast.Ast] = []
|
|
168
167
|
self._class_ast: Optional[ast.ClassDef] = None
|
|
169
168
|
self._root_ast: Optional[ast.FunctionDef] = None
|
|
170
169
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
171
170
|
self._deleted_field = {}
|
|
172
171
|
self._deleted_node = []
|
|
173
|
-
self.
|
|
172
|
+
self._external_ast = []
|
|
174
173
|
self._father_class_ast = []
|
|
175
|
-
|
|
176
|
-
# head node is always point to the first node(in source code order) of SymbolTree
|
|
177
|
-
self._head = None
|
|
178
|
-
# tail node is always point to the last node(in source code order) of SymbolTree
|
|
179
|
-
self._tail = None
|
|
180
|
-
self._return: Optional[Node] = None
|
|
181
|
-
|
|
182
174
|
self._modified = False
|
|
183
|
-
self._node_visitor = None
|
|
184
|
-
|
|
185
175
|
self._tmp_file_limits = 20
|
|
186
176
|
self._tmp_files = []
|
|
187
177
|
self._saved_file_name = "./network_define.py"
|
|
188
178
|
# used to insert "sys.path.append(xxx)"
|
|
189
179
|
self._net_file_paths = []
|
|
180
|
+
self._tmp_import_strs = []
|
|
181
|
+
self._tmp_unmodified_strees: {type, str} = {}
|
|
182
|
+
self._tmp_replacers = []
|
|
183
|
+
# Record imported modules and names of each files
|
|
184
|
+
# The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
|
|
185
|
+
# Format: {file_path: {module: [name, ...], ...}, ...}
|
|
186
|
+
self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
|
|
190
187
|
|
|
191
188
|
def __del__(self):
|
|
192
189
|
for tmp_file in self._tmp_files:
|
|
193
190
|
tmp_file.close()
|
|
194
191
|
|
|
195
192
|
@staticmethod
|
|
196
|
-
def
|
|
197
|
-
"""
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
for node in nodes:
|
|
203
|
-
for arg in node.get_args():
|
|
204
|
-
if consumers.get(arg):
|
|
205
|
-
consumers[arg].append(node)
|
|
206
|
-
else:
|
|
207
|
-
consumers[arg] = [node]
|
|
208
|
-
for _, arg in node.get_kwargs():
|
|
209
|
-
if consumers.get(arg):
|
|
210
|
-
consumers[arg].append(node)
|
|
211
|
-
else:
|
|
212
|
-
consumers[arg] = [node]
|
|
213
|
-
for target in node.get_targets():
|
|
214
|
-
if providers.get(target) is not None:
|
|
215
|
-
raise RuntimeError(f"Target({target}) of node duplicated")
|
|
216
|
-
providers[target] = node
|
|
217
|
-
return consumers, providers
|
|
218
|
-
|
|
219
|
-
@staticmethod
|
|
220
|
-
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
|
|
221
|
-
"""Find all non-duplicated class name of SymbolTree recursively."""
|
|
222
|
-
replacer = AstReplacer(stree.get_class_ast())
|
|
223
|
-
replacers.append(replacer)
|
|
224
|
-
for node in stree.nodes():
|
|
225
|
-
if not isinstance(node, TreeNode):
|
|
193
|
+
def _remove_unused_import(module_ast):
|
|
194
|
+
"""remove unused import in self._module_ast"""
|
|
195
|
+
str_checker = StrChecker(module_ast)
|
|
196
|
+
for i in range(len(module_ast.body) - 1, -1, -1):
|
|
197
|
+
body = module_ast.body[i]
|
|
198
|
+
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
226
199
|
continue
|
|
227
|
-
if
|
|
200
|
+
if isinstance(body, ast.Import):
|
|
228
201
|
continue
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
# all modified ast.ClassDef should export to code
|
|
232
|
-
if sub_stree._modified:
|
|
233
|
-
allow_class_name.append(sub_stree._class_ast.name)
|
|
202
|
+
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
203
|
+
module_ast.body.remove(body)
|
|
234
204
|
continue
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
205
|
+
for alias in body.names:
|
|
206
|
+
name = alias.asname if alias.asname else alias.name
|
|
207
|
+
if not str_checker.check(name):
|
|
208
|
+
if len(body.names) == 1:
|
|
209
|
+
module_ast.body.remove(body)
|
|
210
|
+
i += 1
|
|
211
|
+
else:
|
|
212
|
+
body.names.remove(alias)
|
|
213
|
+
|
|
214
|
+
@staticmethod
|
|
215
|
+
def _remove_duplicated_import(module_ast):
|
|
216
|
+
"""Remove duplicated import of 'net'."""
|
|
217
|
+
imports = set()
|
|
218
|
+
futures = set()
|
|
219
|
+
classes = set()
|
|
220
|
+
|
|
221
|
+
class TransImportNode(ast.NodeTransformer):
|
|
222
|
+
"""Find all import nodes from input ast node."""
|
|
223
|
+
|
|
224
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
225
|
+
class_str = astunparse.unparse(node)
|
|
226
|
+
if class_str not in classes:
|
|
227
|
+
classes.add(node.name)
|
|
228
|
+
return node
|
|
229
|
+
return
|
|
230
|
+
|
|
231
|
+
def visit_Try(self, node: ast.Try) -> Any:
|
|
232
|
+
if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
|
|
233
|
+
import_str = astunparse.unparse(node)
|
|
234
|
+
if import_str not in imports:
|
|
235
|
+
imports.add(import_str)
|
|
236
|
+
return node
|
|
237
|
+
return
|
|
238
|
+
|
|
239
|
+
def visit_Import(self, node: ast.Import) -> Any:
|
|
240
|
+
import_str = astunparse.unparse(node)
|
|
241
|
+
if import_str not in imports:
|
|
242
|
+
imports.add(import_str)
|
|
243
|
+
return node
|
|
244
|
+
return
|
|
245
|
+
|
|
246
|
+
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
247
|
+
"""
|
|
248
|
+
Once the father class 'A' is defined in the current module, all the next imported class 'A' should
|
|
249
|
+
be removed. e.g.
|
|
250
|
+
def class A():
|
|
251
|
+
...
|
|
252
|
+
from xxx import A, B
|
|
253
|
+
=>
|
|
254
|
+
def class A():
|
|
255
|
+
...
|
|
256
|
+
from xxx import B
|
|
257
|
+
"""
|
|
258
|
+
import_str = astunparse.unparse(node)
|
|
259
|
+
|
|
260
|
+
if import_str not in imports:
|
|
261
|
+
imports.add(import_str)
|
|
262
|
+
# remove "__future__" module
|
|
263
|
+
if node.module == '__future__':
|
|
264
|
+
futures.add(node.module)
|
|
265
|
+
return
|
|
266
|
+
# remove modules which have been defined in the code file
|
|
267
|
+
# it occurs when class A is a father class and other sub-classes import A
|
|
268
|
+
for alias in node.names[:]:
|
|
269
|
+
if alias.name in classes:
|
|
270
|
+
node.names.remove(alias)
|
|
271
|
+
# if the alias(es) in node.names are all removed, this import statement should be removed
|
|
272
|
+
if not node.names:
|
|
273
|
+
return
|
|
274
|
+
return node
|
|
275
|
+
return
|
|
276
|
+
|
|
277
|
+
get_node_handler = TransImportNode()
|
|
278
|
+
get_node_handler.generic_visit(module_ast)
|
|
242
279
|
|
|
243
280
|
def finish_build(self):
|
|
244
281
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
245
282
|
self.add_event(Event.TopologicalChangeEvent)
|
|
246
283
|
|
|
247
|
-
def create_assign_node(self, targets, func_name, args, kwargs):
|
|
248
|
-
"""
|
|
249
|
-
Create a ast.Assign type node.
|
|
250
|
-
|
|
251
|
-
Args:
|
|
252
|
-
targets (list): _description_
|
|
253
|
-
func_name (_type_): _description_
|
|
254
|
-
args (_type_): _description_
|
|
255
|
-
kwargs (_type_): _description_
|
|
256
|
-
|
|
257
|
-
Returns:
|
|
258
|
-
_type_: _description_
|
|
259
|
-
"""
|
|
260
|
-
# create targets
|
|
261
|
-
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
262
|
-
# create call
|
|
263
|
-
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
264
|
-
ast_args = ast_creator_registry.get("Args")(args)
|
|
265
|
-
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
266
|
-
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
267
|
-
# create assign
|
|
268
|
-
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
269
|
-
return ast_node
|
|
270
|
-
|
|
271
|
-
def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
|
|
272
|
-
'''
|
|
273
|
-
Instantiate an instance of node whose type is `CallFunction`.
|
|
274
|
-
|
|
275
|
-
Args:
|
|
276
|
-
node_name (str): Name of node.
|
|
277
|
-
func_name (str): Name of function.
|
|
278
|
-
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
279
|
-
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
280
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
281
|
-
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
282
|
-
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
283
|
-
class.
|
|
284
|
-
'''
|
|
285
|
-
logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
|
|
286
|
-
node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
|
|
287
|
-
node.set_belong_symbol_tree(self)
|
|
288
|
-
return node
|
|
289
|
-
|
|
290
284
|
def get_ori_cls_name(self) -> str:
|
|
291
285
|
"""
|
|
292
286
|
Get class name of original network.
|
|
@@ -342,6 +336,7 @@ class SymbolTree(Observer, Observable):
|
|
|
342
336
|
corresponding network class.
|
|
343
337
|
"""
|
|
344
338
|
self._root_ast = ast_node
|
|
339
|
+
NodeManager.set_ast_functiondef(self, ast_node)
|
|
345
340
|
|
|
346
341
|
def get_class_ast(self):
|
|
347
342
|
"""
|
|
@@ -380,18 +375,6 @@ class SymbolTree(Observer, Observable):
|
|
|
380
375
|
"""
|
|
381
376
|
self._init_func_ast = ast_node
|
|
382
377
|
|
|
383
|
-
def get_inputs(self):
|
|
384
|
-
return self._inputs
|
|
385
|
-
|
|
386
|
-
def get_head_node(self):
|
|
387
|
-
"""
|
|
388
|
-
Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
|
|
389
|
-
|
|
390
|
-
Returns:
|
|
391
|
-
An instance of node.
|
|
392
|
-
"""
|
|
393
|
-
return self._head
|
|
394
|
-
|
|
395
378
|
def get_origin_network(self):
|
|
396
379
|
"""
|
|
397
380
|
Getter of `_origin_network`.
|
|
@@ -405,46 +388,53 @@ class SymbolTree(Observer, Observable):
|
|
|
405
388
|
"""Get dict of nodes"""
|
|
406
389
|
return self._nodes
|
|
407
390
|
|
|
408
|
-
def
|
|
409
|
-
"""Get
|
|
410
|
-
return self.
|
|
391
|
+
def get_node_namer(self):
|
|
392
|
+
"""Get _node_namer"""
|
|
393
|
+
return self._node_namer
|
|
411
394
|
|
|
412
|
-
def
|
|
413
|
-
"""
|
|
414
|
-
|
|
415
|
-
self._net_file_paths.append(file_path)
|
|
416
|
-
|
|
417
|
-
def get_net_file_path(self):
|
|
418
|
-
"""Get _net_file_paths"""
|
|
419
|
-
return self._net_file_paths
|
|
395
|
+
def is_modified(self):
|
|
396
|
+
"""
|
|
397
|
+
Check whether symbol tree is modified.
|
|
420
398
|
|
|
421
|
-
|
|
399
|
+
Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
|
|
400
|
+
the symbol tree is created.
|
|
422
401
|
"""
|
|
423
|
-
|
|
402
|
+
return self._modified
|
|
424
403
|
|
|
425
|
-
|
|
426
|
-
A generator for iterating Nodes of `SymbolTree`.
|
|
404
|
+
def set_modified_true(self):
|
|
427
405
|
"""
|
|
428
|
-
|
|
429
|
-
nodes = []
|
|
430
|
-
node = self._head
|
|
431
|
-
while node is not None:
|
|
432
|
-
nodes.append(node)
|
|
433
|
-
node = node.get_next()
|
|
434
|
-
return iter(nodes)
|
|
406
|
+
Set self._modified true.
|
|
435
407
|
|
|
436
|
-
|
|
408
|
+
Self._modified is set true when 'if' exists in the original network.
|
|
409
|
+
In this situation, different original network instance tends to be different.
|
|
410
|
+
Hence, the class name should be updated.
|
|
437
411
|
"""
|
|
438
|
-
|
|
412
|
+
self._modified = True
|
|
439
413
|
|
|
440
|
-
|
|
441
|
-
|
|
414
|
+
def get_import_asts(self):
|
|
415
|
+
"""Get _import_asts"""
|
|
416
|
+
return self._import_asts
|
|
442
417
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
418
|
+
def get_external_ast(self):
|
|
419
|
+
"""Get _external_ast"""
|
|
420
|
+
return self._external_ast
|
|
421
|
+
|
|
422
|
+
def get_father_class_ast(self):
|
|
423
|
+
"""Get _father_class_ast"""
|
|
424
|
+
return self._father_class_ast
|
|
425
|
+
|
|
426
|
+
def get_imported_modules(self, file_path: str):
|
|
427
|
+
"""Get all modules and module_paths in file of `file_path` ."""
|
|
428
|
+
return self._imported_modules.get(file_path, {})
|
|
446
429
|
|
|
447
|
-
|
|
430
|
+
def save_imported_modules(self, file_path: str, module: str, names: List[str]):
|
|
431
|
+
"""Save module and names into _imported_modules."""
|
|
432
|
+
imported_modules = self.get_imported_modules(file_path)
|
|
433
|
+
if imported_modules.get(module):
|
|
434
|
+
imported_modules[module].extend(names)
|
|
435
|
+
else:
|
|
436
|
+
imported_modules[module] = names
|
|
437
|
+
self._imported_modules[file_path] = imported_modules
|
|
448
438
|
|
|
449
439
|
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
|
450
440
|
"""
|
|
@@ -535,9 +525,11 @@ class SymbolTree(Observer, Observable):
|
|
|
535
525
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
536
526
|
return Position.create(node.get_belong_symbol_tree(), node, False)
|
|
537
527
|
|
|
538
|
-
def insert_node(self,
|
|
528
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
|
|
529
|
+
insert_to_ast: bool = True):
|
|
539
530
|
"""
|
|
540
|
-
Insert a node
|
|
531
|
+
Insert a node before or after base_node.
|
|
532
|
+
|
|
541
533
|
Note:
|
|
542
534
|
Name of node will be unique while inserting node into SymbolTree.
|
|
543
535
|
|
|
@@ -556,57 +548,73 @@ class SymbolTree(Observer, Observable):
|
|
|
556
548
|
Topological relation is updated and inputs of corresponding node is updated.
|
|
557
549
|
|
|
558
550
|
Args:
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
551
|
+
new_node (Node): Node to be inserted.
|
|
552
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
553
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
554
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
555
|
+
NodeManager of symboltree's construct function.
|
|
556
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
563
557
|
|
|
564
558
|
Returns:
|
|
565
559
|
An instance of node which has been inserted into SymbolTree.
|
|
566
560
|
|
|
567
561
|
Raises:
|
|
568
562
|
ValueError: Node in the SymbolTree is inserted into SymbolTree again.
|
|
569
|
-
RuntimeError: If 'position' is not in current SymbolTree.
|
|
570
563
|
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
571
564
|
"""
|
|
572
|
-
if
|
|
573
|
-
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
if
|
|
565
|
+
if new_node.get_belong_symbol_tree():
|
|
566
|
+
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
|
|
567
|
+
|
|
568
|
+
# Check if base_node in current SymbolTree
|
|
569
|
+
if base_node is not None:
|
|
570
|
+
stree = base_node.get_belong_symbol_tree()
|
|
571
|
+
if stree is not None and stree is not self:
|
|
572
|
+
raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
|
|
573
|
+
f"current: {self.get_ori_cls_name()}.")
|
|
574
|
+
|
|
575
|
+
# Check if node is inserted between Input node
|
|
576
|
+
if base_node is not None and base_node.get_node_type() == NodeType.Input:
|
|
584
577
|
valid = True
|
|
585
|
-
if
|
|
578
|
+
if before_node:
|
|
586
579
|
valid = False
|
|
587
|
-
if
|
|
580
|
+
if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
|
|
588
581
|
valid = False
|
|
589
582
|
if not valid:
|
|
590
|
-
raise RuntimeError("Can not insert a node before or between parameters:",
|
|
591
|
-
|
|
592
|
-
node_name = self._node_name_namer.get_name(node)
|
|
593
|
-
node.set_name(node_name)
|
|
583
|
+
raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
|
|
584
|
+
|
|
594
585
|
# save target name, which is used to provide unique target
|
|
595
|
-
if
|
|
596
|
-
for target in
|
|
586
|
+
if new_node.get_targets():
|
|
587
|
+
for target in new_node.get_targets():
|
|
597
588
|
self._target_namer.add_name(str(target))
|
|
598
|
-
self._handle_custom_obj_in_normalized_args(node)
|
|
599
|
-
self._insert_node(position, node)
|
|
600
|
-
if isinstance(node, TreeNode):
|
|
601
|
-
node.symbol_tree.reg_observer(self)
|
|
602
|
-
if self._node_visitor:
|
|
603
|
-
self._node_visitor.append_node(node)
|
|
604
|
-
# update init-function-ast and construct-function-ast
|
|
605
|
-
if insert_to_ast:
|
|
606
|
-
self._insert_to_ast_while_insert_node(node, position)
|
|
607
|
-
return node
|
|
608
589
|
|
|
609
|
-
|
|
590
|
+
self._handle_custom_obj_in_normalized_args(new_node)
|
|
591
|
+
|
|
592
|
+
# Insert node into NodeManager
|
|
593
|
+
if node_manager is None:
|
|
594
|
+
if base_node is None:
|
|
595
|
+
raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
|
|
596
|
+
node_manager = base_node.get_node_manager()
|
|
597
|
+
|
|
598
|
+
# set node's _belong_symbol_tree
|
|
599
|
+
new_node.set_belong_symbol_tree(self)
|
|
600
|
+
|
|
601
|
+
if node_manager is self:
|
|
602
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
603
|
+
if insert_to_ast:
|
|
604
|
+
# update init-function-ast and construct-function-ast
|
|
605
|
+
self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
|
|
606
|
+
else:
|
|
607
|
+
node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
|
|
608
|
+
|
|
609
|
+
# register code changed event observer, which is used to update _modified flag.
|
|
610
|
+
if new_node.get_node_type() == NodeType.Tree:
|
|
611
|
+
new_node.symbol_tree.reg_observer(self)
|
|
612
|
+
elif isinstance(new_node, NodeManager):
|
|
613
|
+
new_node.reg_observer(self)
|
|
614
|
+
|
|
615
|
+
return new_node
|
|
616
|
+
|
|
617
|
+
def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
|
|
610
618
|
"""
|
|
611
619
|
Append a node to SymbolTree.
|
|
612
620
|
|
|
@@ -614,13 +622,17 @@ class SymbolTree(Observer, Observable):
|
|
|
614
622
|
node (Node): An instance of node to be appended.
|
|
615
623
|
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
616
624
|
True.
|
|
625
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
626
|
+
NodeManager of symboltree's construct function.
|
|
617
627
|
|
|
618
628
|
Returns:
|
|
619
629
|
An instance of node which has been appended to SymbolTree.
|
|
620
630
|
"""
|
|
621
|
-
|
|
631
|
+
if node_manager is None:
|
|
632
|
+
node_manager = self
|
|
633
|
+
return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
|
|
622
634
|
|
|
623
|
-
def append_origin_field(self, node: Node) -> Node:
|
|
635
|
+
def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
|
|
624
636
|
"""
|
|
625
637
|
Append an original field node to SymbolTree. An original field node represents a node created from existing
|
|
626
638
|
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
|
|
@@ -629,26 +641,16 @@ class SymbolTree(Observer, Observable):
|
|
|
629
641
|
|
|
630
642
|
Args:
|
|
631
643
|
node (Node): An instance of node to be appended.
|
|
644
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
645
|
+
NodeManager of symboltree's construct function.
|
|
632
646
|
|
|
633
647
|
Returns:
|
|
634
648
|
An instance of node which has been appended to SymbolTree.
|
|
635
649
|
"""
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
elif node.get_node_type() == NodeType.Tree:
|
|
641
|
-
# add father_class_ast into main tree, used when get_code
|
|
642
|
-
for father_ast in node.symbol_tree.get_father_class_ast():
|
|
643
|
-
if father_ast not in self._father_class_ast:
|
|
644
|
-
self._father_class_ast.append(father_ast)
|
|
645
|
-
# add subtree's net path into main tree
|
|
646
|
-
for file_path in node.symbol_tree.get_net_file_path():
|
|
647
|
-
if file_path not in self._net_file_paths:
|
|
648
|
-
self.append_net_file_path(file_path)
|
|
649
|
-
return self.append_node(node, False)
|
|
650
|
-
|
|
651
|
-
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None):
|
|
650
|
+
return self.append_node(node, node_manager, False)
|
|
651
|
+
|
|
652
|
+
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
|
|
653
|
+
node_manager: NodeManager = None):
|
|
652
654
|
"""
|
|
653
655
|
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
|
|
654
656
|
This method is called while building SymbolTree usually.
|
|
@@ -658,13 +660,18 @@ class SymbolTree(Observer, Observable):
|
|
|
658
660
|
param_name (str): A str represents name of parameter of forward method of network class.
|
|
659
661
|
default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
|
|
660
662
|
means parameter has no default value.
|
|
663
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
664
|
+
NodeManager of symboltree's construct function.
|
|
661
665
|
|
|
662
666
|
Returns:
|
|
663
667
|
An instance of input node which has been appended to SymbolTree.
|
|
664
668
|
"""
|
|
665
669
|
if param_name == "self":
|
|
666
670
|
return
|
|
667
|
-
|
|
671
|
+
# check param_name duplicated
|
|
672
|
+
if node_manager is None:
|
|
673
|
+
node_manager = self
|
|
674
|
+
for input_node in node_manager._inputs:
|
|
668
675
|
targets = input_node.get_targets()
|
|
669
676
|
if len(targets) != 1:
|
|
670
677
|
raise RuntimeError("targets should have 1 elements")
|
|
@@ -677,9 +684,10 @@ class SymbolTree(Observer, Observable):
|
|
|
677
684
|
if exist_param == param_name:
|
|
678
685
|
raise RuntimeError("input duplicated:", param_name)
|
|
679
686
|
input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
|
|
680
|
-
self.append_origin_field(input_node)
|
|
687
|
+
self.append_origin_field(input_node, node_manager)
|
|
681
688
|
|
|
682
|
-
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST
|
|
689
|
+
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
|
|
690
|
+
node_manager: NodeManager = None) -> Optional[Node]:
|
|
683
691
|
"""
|
|
684
692
|
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
|
|
685
693
|
a list or a dict.
|
|
@@ -688,6 +696,8 @@ class SymbolTree(Observer, Observable):
|
|
|
688
696
|
Args:
|
|
689
697
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
690
698
|
ast_node (ast.AST): A ast node represents ast node.
|
|
699
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
700
|
+
NodeManager of symboltree's construct function.
|
|
691
701
|
|
|
692
702
|
Returns:
|
|
693
703
|
An instance of python node if a new node has been appended to SymbolTree else None.
|
|
@@ -696,9 +706,9 @@ class SymbolTree(Observer, Observable):
|
|
|
696
706
|
return None
|
|
697
707
|
if isinstance(ast_node, (list, dict)) and not ast_node:
|
|
698
708
|
return None
|
|
699
|
-
return self.append_python_node(ast_scope, ast_node)
|
|
709
|
+
return self.append_python_node(ast_scope, ast_node, node_manager)
|
|
700
710
|
|
|
701
|
-
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
|
|
711
|
+
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
|
|
702
712
|
"""
|
|
703
713
|
Append a python node to SymbolTree.
|
|
704
714
|
This method is called while building SymbolTree usually.
|
|
@@ -706,39 +716,50 @@ class SymbolTree(Observer, Observable):
|
|
|
706
716
|
Args:
|
|
707
717
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
708
718
|
ast_node (ast.AST): A ast node represents ast node.
|
|
719
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
720
|
+
NodeManager of symboltree's construct function.
|
|
709
721
|
|
|
710
722
|
Returns:
|
|
711
723
|
An instance of python node which has been appended to SymbolTree.
|
|
712
724
|
"""
|
|
713
725
|
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
714
|
-
node_name =
|
|
726
|
+
node_name = type(ast_node).__name__
|
|
715
727
|
node = Node.create_python_node(ast_node, node_name)
|
|
716
|
-
|
|
728
|
+
if node_manager is None or node_manager is self:
|
|
729
|
+
NodeManager.append_python_node(self, node)
|
|
730
|
+
else:
|
|
731
|
+
node_manager.append_python_node(node)
|
|
717
732
|
return node
|
|
718
733
|
|
|
719
|
-
def set_output(self, return_value: str,
|
|
734
|
+
def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
|
|
735
|
+
node_manager: NodeManager = None) -> Node:
|
|
720
736
|
"""
|
|
721
737
|
Update return value of return of forward method of network class.
|
|
722
738
|
|
|
723
739
|
Args:
|
|
724
740
|
return_value (str): A str represents new return value.
|
|
725
|
-
|
|
741
|
+
arg_index (int): A int indicates which value in return to be updated.
|
|
742
|
+
return_idx (int): A int indicates which return node to be updated. Default: 0.
|
|
743
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
|
|
744
|
+
symboltree's construct function.
|
|
726
745
|
|
|
727
746
|
Returns:
|
|
728
747
|
An instance of node represents return node after updated.
|
|
729
748
|
"""
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
749
|
+
node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
|
|
750
|
+
if not node_returns:
|
|
751
|
+
raise RuntimeError("Current node_manager has no output")
|
|
752
|
+
if return_idx >= len(node_returns):
|
|
753
|
+
raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
|
|
754
|
+
node_return = node_returns[return_idx]
|
|
755
|
+
self.set_node_arg(node_return, arg_index, return_value)
|
|
756
|
+
return node_return
|
|
734
757
|
|
|
735
758
|
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
|
|
736
759
|
"""
|
|
737
760
|
Erase a node from SymbolTree.
|
|
738
|
-
Note:
|
|
739
|
-
If node is depended on by other node, RuntimeError will raise.
|
|
740
761
|
|
|
741
|
-
|
|
762
|
+
Topological relation will be updated.
|
|
742
763
|
|
|
743
764
|
Args:
|
|
744
765
|
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
@@ -754,19 +775,21 @@ class SymbolTree(Observer, Observable):
|
|
|
754
775
|
node = self._get_real_node(node_or_name)
|
|
755
776
|
if node is None:
|
|
756
777
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
778
|
+
# erase node in NodeManager
|
|
779
|
+
node_manager = node.get_node_manager()
|
|
780
|
+
|
|
781
|
+
logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
|
|
782
|
+
f"node_manager: {node_manager.get_manager_name()}, "
|
|
783
|
+
f"code: {astunparse.unparse(node.get_ast()).strip()}, "
|
|
784
|
+
f"node_name:{node.get_name()}")
|
|
785
|
+
|
|
786
|
+
if node_manager is self:
|
|
787
|
+
NodeManager.erase_node(self, node)
|
|
788
|
+
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
789
|
+
if not ret:
|
|
790
|
+
raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
791
|
+
else:
|
|
792
|
+
node_manager.erase_node(node)
|
|
770
793
|
self._deleted_node.append(node.get_name())
|
|
771
794
|
return node
|
|
772
795
|
|
|
@@ -785,25 +808,16 @@ class SymbolTree(Observer, Observable):
|
|
|
785
808
|
RuntimeError: If 'old_node' is isolated.
|
|
786
809
|
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
787
810
|
"""
|
|
788
|
-
|
|
789
|
-
if hasattr(old_node, "container"):
|
|
790
|
-
self._replace_container_node(old_node, new_nodes)
|
|
791
|
-
return new_nodes[0]
|
|
792
811
|
real_old_node = self._get_real_node(old_node)
|
|
793
812
|
if real_old_node is None:
|
|
794
813
|
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
795
|
-
#
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
raise RuntimeError("Try replacing a isolated node: ", old_node)
|
|
800
|
-
if prev_node is None:
|
|
801
|
-
position = self.before(next_node)
|
|
802
|
-
else:
|
|
803
|
-
position = self.after(prev_node)
|
|
814
|
+
# insert new_nodes into node_manager
|
|
815
|
+
node_manager = real_old_node.get_node_manager()
|
|
816
|
+
# insert new_nodes into NodeManager
|
|
817
|
+
base_node = old_node
|
|
804
818
|
for node in new_nodes:
|
|
805
|
-
self.insert_node(
|
|
806
|
-
|
|
819
|
+
self.insert_node(node, base_node, False, node_manager, True)
|
|
820
|
+
base_node = node
|
|
807
821
|
self.erase_node(old_node)
|
|
808
822
|
return new_nodes[-1]
|
|
809
823
|
|
|
@@ -868,6 +882,15 @@ class SymbolTree(Observer, Observable):
|
|
|
868
882
|
"""Get a unique name in the symboltree"""
|
|
869
883
|
return self._target_namer.get_name(name)
|
|
870
884
|
|
|
885
|
+
def unique_func_name(self, name: str):
|
|
886
|
+
"""Get a unique function name in the symboltree"""
|
|
887
|
+
if not hasattr(self._origin_network, name):
|
|
888
|
+
return name
|
|
889
|
+
suffix = 1
|
|
890
|
+
while hasattr(self._origin_network, f"{name}_{suffix}"):
|
|
891
|
+
suffix += 1
|
|
892
|
+
return f"{name}_{suffix}"
|
|
893
|
+
|
|
871
894
|
def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
|
|
872
895
|
"""
|
|
873
896
|
Set target of `node` .
|
|
@@ -895,21 +918,191 @@ class SymbolTree(Observer, Observable):
|
|
|
895
918
|
node.set_targets(targets)
|
|
896
919
|
self._topo_mgr.on_update_target(node, index, old_target, target)
|
|
897
920
|
|
|
898
|
-
def
|
|
899
|
-
|
|
921
|
+
def all_nodes(self):
|
|
922
|
+
"""
|
|
923
|
+
Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
|
|
924
|
+
|
|
925
|
+
Returns:
|
|
926
|
+
A list of nodes.
|
|
927
|
+
"""
|
|
928
|
+
nodes = []
|
|
929
|
+
node_managers = [self]
|
|
930
|
+
while node_managers:
|
|
931
|
+
node_manager = node_managers.pop()
|
|
932
|
+
nodes.extend(node_manager.nodes())
|
|
933
|
+
for node in node_manager.nodes():
|
|
934
|
+
if isinstance(node, NodeManager):
|
|
935
|
+
node_managers.append(node)
|
|
936
|
+
for tree_node in self.get_tree_nodes():
|
|
937
|
+
stree = tree_node.symbol_tree
|
|
938
|
+
nodes.extend(stree.all_nodes())
|
|
939
|
+
return nodes
|
|
940
|
+
|
|
941
|
+
def get_node_from_name(self, node_name: str):
|
|
942
|
+
"""
|
|
943
|
+
Get node from all NodeManagers in current symbol tree by `node_name`.
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
node_name (str): A str represents name of node as key of query.
|
|
947
|
+
|
|
948
|
+
Returns:
|
|
949
|
+
An instance of Node if found else None.
|
|
950
|
+
"""
|
|
951
|
+
node_managers = [self]
|
|
952
|
+
while node_managers:
|
|
953
|
+
node_manager = node_managers.pop()
|
|
954
|
+
node = node_manager.get_node(node_name)
|
|
955
|
+
if node:
|
|
956
|
+
return node
|
|
957
|
+
for node in node_manager.nodes():
|
|
958
|
+
if isinstance(node, NodeManager):
|
|
959
|
+
node_managers.append(node)
|
|
960
|
+
return None
|
|
961
|
+
|
|
962
|
+
def print_node_tabulate(self, all_nodes: bool = False):
|
|
963
|
+
"""
|
|
964
|
+
Print nodes information and nodes' topological relations.
|
|
965
|
+
|
|
966
|
+
Args:
|
|
967
|
+
all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
|
|
968
|
+
nodes, CellContainer nodes and sub symbol trees.
|
|
969
|
+
"""
|
|
970
|
+
try:
|
|
971
|
+
from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
|
|
972
|
+
except ImportError:
|
|
973
|
+
logger.warning("print_node_tabulate relies on the library `tabulate`, "
|
|
974
|
+
"which could not be found on this machine. Run `pip "
|
|
975
|
+
"install tabulate` to install the library.")
|
|
976
|
+
return ""
|
|
977
|
+
print(NodeManager.dump(self, self.get_manager_name()))
|
|
978
|
+
if all_nodes:
|
|
979
|
+
node_managers = [self]
|
|
980
|
+
while node_managers:
|
|
981
|
+
node_manager = node_managers.pop()
|
|
982
|
+
for node in node_manager.nodes():
|
|
983
|
+
if isinstance(node, NodeManager):
|
|
984
|
+
print(node.dump(node.get_manager_name()))
|
|
985
|
+
node_managers.append(node)
|
|
986
|
+
for tree_node in self.get_tree_nodes():
|
|
987
|
+
stree = tree_node.symbol_tree
|
|
988
|
+
stree.print_node_tabulate(all_nodes)
|
|
900
989
|
|
|
901
990
|
def dump(self):
|
|
902
991
|
"""Dump graph."""
|
|
903
992
|
dump_st = SymbolTreeDumper(self)
|
|
904
993
|
dump_st.dump()
|
|
905
994
|
|
|
906
|
-
def
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
995
|
+
def check_body_exist(self, body, code_bodies):
|
|
996
|
+
"""Check whether body already exist in code_bodies"""
|
|
997
|
+
# Check import ast node exist by saving import code string to self._tmp_import_strs
|
|
998
|
+
if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
|
|
999
|
+
import_str = astunparse.unparse(body)
|
|
1000
|
+
if import_str in self._tmp_import_strs:
|
|
1001
|
+
return True
|
|
1002
|
+
self._tmp_import_strs.append(import_str)
|
|
1003
|
+
return False
|
|
1004
|
+
|
|
1005
|
+
# Check ClassDef ast node exist by using AstClassFinder
|
|
1006
|
+
if isinstance(body, ast.ClassDef):
|
|
1007
|
+
if sys.version_info >= (3, 9):
|
|
1008
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1009
|
+
else:
|
|
1010
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies))
|
|
1011
|
+
results = class_finder.find_all(body.name)
|
|
1012
|
+
return bool(results)
|
|
1013
|
+
|
|
1014
|
+
# Check FunctionDef ast node exist by using AstFunctionFinder
|
|
1015
|
+
if isinstance(body, ast.FunctionDef):
|
|
1016
|
+
if sys.version_info >= (3, 9):
|
|
1017
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1018
|
+
else:
|
|
1019
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
|
|
1020
|
+
results = function_finder.find_all(body.name)
|
|
1021
|
+
return bool(results)
|
|
1022
|
+
|
|
1023
|
+
return False
|
|
1024
|
+
|
|
1025
|
+
def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
|
|
1026
|
+
"""
|
|
1027
|
+
For the unmodified symbol tree, only one definition code remains in the generated code.
|
|
1028
|
+
Everywhere else calling this symbol tree will use the class in this definition code.
|
|
1029
|
+
"""
|
|
1030
|
+
# all modified ast.ClassDef will be exported to code
|
|
1031
|
+
if stree.is_modified():
|
|
1032
|
+
return False
|
|
1033
|
+
# all un-modified ast.ClassDef only keep one instance
|
|
1034
|
+
first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
|
|
1035
|
+
if first_cls_name is None:
|
|
1036
|
+
class_ast = stree.get_class_ast()
|
|
1037
|
+
if class_ast:
|
|
1038
|
+
self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
|
|
1039
|
+
return False
|
|
1040
|
+
# Un-modified ast.ClassDef already exist in code_bodies,
|
|
1041
|
+
# replace class name to class name of first un-modified ast.ClassDef.
|
|
1042
|
+
if sys.version_info >= (3, 9):
|
|
1043
|
+
replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1044
|
+
else:
|
|
1045
|
+
replacer = AstReplacer(ast.Module(body=code_bodies))
|
|
1046
|
+
replacer.replace_all(stree.get_class_ast().name, first_cls_name)
|
|
1047
|
+
self._tmp_replacers.append(replacer)
|
|
1048
|
+
return True
|
|
1049
|
+
|
|
1050
|
+
def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
|
|
1051
|
+
"""
|
|
1052
|
+
Convert nodes in stree to code_bodies
|
|
1053
|
+
|
|
1054
|
+
1. Add import asts into code_bodies
|
|
1055
|
+
2. Add class, function and other type of asts into code_bodies
|
|
1056
|
+
3. Add father class asts into code_bodies
|
|
1057
|
+
4. Add external function asts into code_bodies
|
|
1058
|
+
5. Add subtrees to code_bodies
|
|
1059
|
+
5.1 Add subtrees in construct to code_bodies
|
|
1060
|
+
5.2 Add subtrees in CellContainers to code_bodies
|
|
1061
|
+
|
|
1062
|
+
"""
|
|
1063
|
+
# Add import asts into code_bodies
|
|
1064
|
+
for body in stree.get_import_asts():
|
|
1065
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1066
|
+
code_bodies.insert(insert_pos, body)
|
|
1067
|
+
insert_pos += 1
|
|
1068
|
+
|
|
1069
|
+
# Add class, function and other type of asts into code_bodies
|
|
1070
|
+
if stree.get_module_ast():
|
|
1071
|
+
for body in stree.get_module_ast().body:
|
|
1072
|
+
if self.check_body_exist(body, code_bodies):
|
|
1073
|
+
continue
|
|
1074
|
+
if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
|
|
1075
|
+
code_bodies.insert(insert_pos, body)
|
|
1076
|
+
else:
|
|
1077
|
+
code_bodies.append(body)
|
|
1078
|
+
|
|
1079
|
+
# Add father class asts into code_bodies
|
|
1080
|
+
for body in reversed(stree.get_father_class_ast()):
|
|
1081
|
+
if self.check_body_exist(body, code_bodies):
|
|
1082
|
+
# remove exist ast in old position, then insert ast to upper position
|
|
1083
|
+
if sys.version_info >= (3, 9):
|
|
1084
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
|
|
1085
|
+
else:
|
|
1086
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
|
|
1087
|
+
code_bodies.remove(exist_ast)
|
|
1088
|
+
code_bodies.insert(insert_pos, body)
|
|
1089
|
+
|
|
1090
|
+
# Add external asts into code_bodies
|
|
1091
|
+
for body in stree.get_external_ast():
|
|
1092
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1093
|
+
code_bodies.insert(insert_pos, body)
|
|
1094
|
+
insert_pos += 1
|
|
1095
|
+
|
|
1096
|
+
# Add subtrees to code_bodies
|
|
1097
|
+
for node in stree.get_tree_nodes():
|
|
1098
|
+
sub_stree = node.symbol_tree
|
|
1099
|
+
# Ignore TreeNode create by function in the class
|
|
1100
|
+
if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
|
|
1101
|
+
continue
|
|
1102
|
+
# For the unmodified class, update class name to name of first class
|
|
1103
|
+
if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
|
|
1104
|
+
continue
|
|
1105
|
+
self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
|
|
913
1106
|
|
|
914
1107
|
def get_code(self) -> str:
|
|
915
1108
|
"""
|
|
@@ -918,34 +1111,22 @@ class SymbolTree(Observer, Observable):
|
|
|
918
1111
|
Returns:
|
|
919
1112
|
A str represents source code of modified network.
|
|
920
1113
|
"""
|
|
921
|
-
self.
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
self.
|
|
1114
|
+
self._tmp_import_strs.clear()
|
|
1115
|
+
self._tmp_unmodified_strees.clear()
|
|
1116
|
+
self._tmp_replacers.clear()
|
|
1117
|
+
code_bodies = []
|
|
1118
|
+
self.convert_stree_to_code_bodies(self, code_bodies)
|
|
1119
|
+
if sys.version_info >= (3, 9):
|
|
1120
|
+
gencode_module = ast.Module(body=code_bodies, type_ignores=[])
|
|
1121
|
+
else:
|
|
1122
|
+
gencode_module = ast.Module(body=code_bodies)
|
|
1123
|
+
SymbolTree._remove_unused_import(gencode_module)
|
|
1124
|
+
SymbolTree._remove_duplicated_import(gencode_module)
|
|
926
1125
|
ast.fix_missing_locations(self._module_ast)
|
|
927
|
-
|
|
928
|
-
# Replace duplicated ast.ClassDef reference in main-ClassDef
|
|
929
|
-
seen_class: {type, str} = {}
|
|
930
|
-
allow_class_name = [self._class_ast.name]
|
|
931
|
-
replacers = []
|
|
932
|
-
SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
|
|
933
|
-
# Add all non-ClassDef body to gencode_module
|
|
934
|
-
# Add all ClassDef in allow_class_name to gencode_module
|
|
935
|
-
# Use gencode_module to generate code
|
|
936
|
-
bodies = []
|
|
937
|
-
for body in self._module_ast.body:
|
|
938
|
-
if not isinstance(body, ast.ClassDef):
|
|
939
|
-
bodies.append(body)
|
|
940
|
-
continue
|
|
941
|
-
if body.name in allow_class_name:
|
|
942
|
-
bodies.append(body)
|
|
943
|
-
gencode_module = ast.Module(body=bodies)
|
|
944
|
-
if_fixer = IfFixer()
|
|
945
|
-
if_fixer.fix(gencode_module)
|
|
1126
|
+
IfFixer().fix(gencode_module)
|
|
946
1127
|
code = astunparse.unparse(gencode_module)
|
|
947
|
-
#
|
|
948
|
-
for replacer in
|
|
1128
|
+
# Revert the class name to its original state
|
|
1129
|
+
for replacer in self._tmp_replacers:
|
|
949
1130
|
replacer.undo_all()
|
|
950
1131
|
return code
|
|
951
1132
|
|
|
@@ -979,251 +1160,71 @@ class SymbolTree(Observer, Observable):
|
|
|
979
1160
|
f.write(source.encode('utf-8'))
|
|
980
1161
|
f.flush()
|
|
981
1162
|
|
|
982
|
-
def
|
|
1163
|
+
def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
|
|
1164
|
+
node_manager: NodeManager):
|
|
983
1165
|
""" insert_to_ast_while_insert_node. """
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
992
|
-
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
993
|
-
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
994
|
-
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
995
|
-
|
|
996
|
-
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
997
|
-
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
998
|
-
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
999
|
-
|
|
1000
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1001
|
-
None if position is None else position.node.get_ast(),
|
|
1002
|
-
position.before_node)
|
|
1003
|
-
sub_stree: SymbolTree = node.symbol_tree
|
|
1004
|
-
from .symbol_tree_builder import SymbolTreeBuilder
|
|
1005
|
-
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
1166
|
+
if new_node.get_node_type() == NodeType.Input:
|
|
1167
|
+
# insert a new input
|
|
1168
|
+
self._inputs.append(new_node)
|
|
1169
|
+
ast_construct = self.get_ast_root()
|
|
1170
|
+
arg: str = new_node.get_targets()[0].value
|
|
1171
|
+
ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
|
|
1172
|
+
AstModifier.append_arg_to_function(ast_construct, ast_arg)
|
|
1006
1173
|
else:
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
body = self._module_ast.body[i]
|
|
1022
|
-
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
1023
|
-
continue
|
|
1024
|
-
if isinstance(body, ast.Import):
|
|
1025
|
-
continue
|
|
1026
|
-
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
1027
|
-
self._module_ast.body.remove(body)
|
|
1028
|
-
continue
|
|
1029
|
-
for alias in body.names:
|
|
1030
|
-
name = alias.asname if alias.asname else alias.name
|
|
1031
|
-
if not str_checker.check(name):
|
|
1032
|
-
if len(body.names) == 1:
|
|
1033
|
-
self._module_ast.body.remove(body)
|
|
1034
|
-
i += 1
|
|
1035
|
-
else:
|
|
1036
|
-
body.names.remove(alias)
|
|
1037
|
-
|
|
1038
|
-
def _replace_container_node(self, old_node, new_nodes):
|
|
1039
|
-
cellcontainer = getattr(old_node, "container")
|
|
1040
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1041
|
-
for n in reversed(new_nodes):
|
|
1042
|
-
cellcontainer.insert(index, n)
|
|
1043
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1044
|
-
cellcontainer.erase(old_node)
|
|
1045
|
-
|
|
1046
|
-
def _filter_out_to_delete_field(self, to_delete_field):
|
|
1047
|
-
"""filter out used field from `to_delete_field`"""
|
|
1048
|
-
for func_def in self._class_ast.body:
|
|
1049
|
-
if not isinstance(func_def, ast.FunctionDef):
|
|
1050
|
-
continue
|
|
1051
|
-
if func_def.name != "__init__":
|
|
1052
|
-
to_delete_to_delete_keys = []
|
|
1053
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1054
|
-
for key, _ in self._deleted_field.items():
|
|
1055
|
-
if property_checker.check("self", key):
|
|
1056
|
-
to_delete_to_delete_keys.append(key)
|
|
1057
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1058
|
-
for key in to_delete_to_delete_keys:
|
|
1059
|
-
self._deleted_field.pop(key)
|
|
1174
|
+
# insert a new assign statement
|
|
1175
|
+
ast_assign = new_node.get_ast()
|
|
1176
|
+
if ast_assign is None:
|
|
1177
|
+
func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
|
|
1178
|
+
new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
|
|
1179
|
+
ast_assign = new_node.update_ast_node()
|
|
1180
|
+
if not isinstance(ast_assign, ast.Assign):
|
|
1181
|
+
raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
|
|
1182
|
+
# Save instance into _origin_network.
|
|
1183
|
+
setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
|
|
1184
|
+
# Insert ast to __init__ function
|
|
1185
|
+
if isinstance(new_node, TreeNode):
|
|
1186
|
+
init_code = f"self.{new_node.get_name()} = " \
|
|
1187
|
+
f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
|
|
1060
1188
|
else:
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
self._deleted_field.pop(key)
|
|
1072
|
-
|
|
1073
|
-
def _remove_unused_field(self):
|
|
1074
|
-
"""remove unused field in __init__ function"""
|
|
1075
|
-
multi_targets = []
|
|
1076
|
-
self._deleted_field = {}
|
|
1077
|
-
for index, body in enumerate(self._init_func_ast.body):
|
|
1078
|
-
if not isinstance(body, ast.Assign):
|
|
1079
|
-
continue
|
|
1080
|
-
targets = body.targets
|
|
1081
|
-
for target in targets:
|
|
1082
|
-
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
|
|
1083
|
-
and target.value.id == "self":
|
|
1084
|
-
self._deleted_field[target.attr] = index
|
|
1085
|
-
if len(targets) > 1:
|
|
1086
|
-
multi_targets.append(index)
|
|
1087
|
-
self._filter_out_to_delete_field(self._deleted_field)
|
|
1088
|
-
for i in range(len(self._init_func_ast.body) - 1, -1, -1):
|
|
1089
|
-
if i in self._deleted_field.values():
|
|
1090
|
-
if i in multi_targets:
|
|
1091
|
-
raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
|
|
1092
|
-
AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
|
|
1093
|
-
ast.fix_missing_locations(self._init_func_ast)
|
|
1094
|
-
|
|
1095
|
-
def _remove_duplicated_import(self):
|
|
1096
|
-
"""Remove duplicated import of 'net'."""
|
|
1097
|
-
imports = []
|
|
1098
|
-
for body in self._module_ast.body:
|
|
1099
|
-
if isinstance(body, (ast.ImportFrom, ast.Import)):
|
|
1100
|
-
import_str = astunparse.unparse(body)
|
|
1101
|
-
if import_str not in imports:
|
|
1102
|
-
imports.append(import_str)
|
|
1103
|
-
else:
|
|
1104
|
-
self._module_ast.body.remove(body)
|
|
1189
|
+
init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
|
|
1190
|
+
init_ast = ast.parse(init_code).body[0]
|
|
1191
|
+
AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
|
|
1192
|
+
# Insert ast to construct_function/class_internal_function
|
|
1193
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
1194
|
+
ast_functiondef = node_manager.get_ast_functiondef()
|
|
1195
|
+
if not ast_functiondef:
|
|
1196
|
+
raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
|
|
1197
|
+
"when inserting the ast.")
|
|
1198
|
+
AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
|
|
1105
1199
|
|
|
1106
1200
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1107
1201
|
if isinstance(node_or_name, str):
|
|
1108
1202
|
return self.get_node(node_or_name)
|
|
1109
1203
|
return node_or_name
|
|
1110
1204
|
|
|
1111
|
-
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
|
1112
|
-
"""
|
|
1113
|
-
Insert a node-tree into SymbolTree.
|
|
1114
|
-
Note:
|
|
1115
|
-
Inputs of intra sub-tree nodes need to be welly set.
|
|
1116
|
-
|
|
1117
|
-
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
|
1118
|
-
|
|
1119
|
-
Args:
|
|
1120
|
-
position (Position): A Position indicates an insert position point.
|
|
1121
|
-
root (Node): An instance of node as root of node-tree to be inserted in.
|
|
1122
|
-
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
1123
|
-
True.
|
|
1124
|
-
|
|
1125
|
-
Returns:
|
|
1126
|
-
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
|
1127
|
-
|
|
1128
|
-
Raises:
|
|
1129
|
-
RuntimeError: If 'position' is not in current SymbolTree.
|
|
1130
|
-
"""
|
|
1131
|
-
|
|
1132
|
-
# if position not in current SymbolTree
|
|
1133
|
-
if position.symbol_tree is not self:
|
|
1134
|
-
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
|
1135
|
-
|
|
1136
|
-
queue: [Node] = [root]
|
|
1137
|
-
todos: [] = []
|
|
1138
|
-
inputs_list: [] = []
|
|
1139
|
-
while queue:
|
|
1140
|
-
cur_node = queue.pop(0)
|
|
1141
|
-
if cur_node in todos:
|
|
1142
|
-
continue
|
|
1143
|
-
todos.append(cur_node)
|
|
1144
|
-
node_inputs = cur_node.get_inputs()
|
|
1145
|
-
inputs_list.append(node_inputs)
|
|
1146
|
-
for node_input in node_inputs:
|
|
1147
|
-
if node_input is not None:
|
|
1148
|
-
queue.append(node_input)
|
|
1149
|
-
todos.reverse()
|
|
1150
|
-
inputs_list.reverse()
|
|
1151
|
-
for index, todo in enumerate(todos):
|
|
1152
|
-
self.insert_node(position, todo, insert_to_ast)
|
|
1153
|
-
position = self.after(todo)
|
|
1154
|
-
# relink input of node
|
|
1155
|
-
original_inputs = inputs_list[index]
|
|
1156
|
-
for arg_idx, original_input in enumerate(original_inputs):
|
|
1157
|
-
if original_input is not None:
|
|
1158
|
-
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
|
1159
|
-
return root
|
|
1160
|
-
|
|
1161
|
-
def _add_node2nodes(self, node: Node):
|
|
1162
|
-
"""
|
|
1163
|
-
Add `node` to `_nodes` dict.
|
|
1164
|
-
|
|
1165
|
-
Args:
|
|
1166
|
-
node (Node): A Node to be added into `_nodes`.
|
|
1167
|
-
|
|
1168
|
-
Raises:
|
|
1169
|
-
RuntimeError: If name of the node is duplicated.
|
|
1170
|
-
"""
|
|
1171
|
-
node_name = node.get_name()
|
|
1172
|
-
if self._nodes.get(node_name) is not None:
|
|
1173
|
-
raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
|
|
1174
|
-
node)
|
|
1175
|
-
self._nodes[node_name] = node
|
|
1176
|
-
|
|
1177
|
-
def _insert_node(self, position: Optional[Position], node: Node):
|
|
1178
|
-
"""
|
|
1179
|
-
Insert a node into SymbolTree.
|
|
1180
|
-
1. Add `node` to `_nodes`.
|
|
1181
|
-
2. Insert `node` to node list(source code order).
|
|
1182
|
-
3. Update topological relation and update inputs of `node`.
|
|
1183
|
-
|
|
1184
|
-
Args:
|
|
1185
|
-
position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
|
|
1186
|
-
of SymbolTree.
|
|
1187
|
-
node (Node): A Node to be inserted into SymbolTree.
|
|
1188
|
-
|
|
1189
|
-
Raises:
|
|
1190
|
-
RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
|
|
1191
|
-
unless inserting first node.
|
|
1192
|
-
"""
|
|
1193
|
-
if position is None:
|
|
1194
|
-
if self._nodes:
|
|
1195
|
-
raise RuntimeError("self._nodes should be empty")
|
|
1196
|
-
self._head = node
|
|
1197
|
-
else:
|
|
1198
|
-
if position.before_node:
|
|
1199
|
-
position.node.insert_before(node)
|
|
1200
|
-
else:
|
|
1201
|
-
position.node.insert_after(node)
|
|
1202
|
-
self._tail = node
|
|
1203
|
-
self._add_node2nodes(node)
|
|
1204
|
-
self._topo_mgr.on_insert_node(node)
|
|
1205
|
-
node.set_belong_symbol_tree(self)
|
|
1206
|
-
|
|
1207
1205
|
def _handle_custom_obj_in_normalized_args(self, node: Node):
|
|
1208
1206
|
"""
|
|
1209
|
-
Convert CustomObjValue type argument to NamingValue type argument by storing custom object
|
|
1207
|
+
Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
|
|
1210
1208
|
|
|
1211
1209
|
Args:
|
|
1212
1210
|
node (Node): A Node whose arguments and keyword arguments to be handled.
|
|
1213
1211
|
"""
|
|
1214
|
-
|
|
1215
|
-
for
|
|
1212
|
+
normalized_args: {str, ScopedValue} = {}
|
|
1213
|
+
for key, value in node.get_normalized_args().items():
|
|
1216
1214
|
if not isinstance(value, ScopedValue):
|
|
1217
1215
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
1218
1216
|
if value.type == ValueType.CustomObjValue:
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1217
|
+
# Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
|
|
1218
|
+
arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
|
|
1219
|
+
setattr(self._origin_network, arg_name, value.value)
|
|
1220
|
+
# Add new code to __init__(): self.arg_name = obj.arg_name
|
|
1221
|
+
new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
|
|
1222
|
+
self._init_func_ast.body.append(new_ast)
|
|
1223
|
+
# Modify node's normalized_args: CustomObjValue -> self.arg_name
|
|
1224
|
+
normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
|
|
1224
1225
|
else:
|
|
1225
|
-
|
|
1226
|
-
node.set_normalized_args(
|
|
1226
|
+
normalized_args[key] = value
|
|
1227
|
+
node.set_normalized_args(normalized_args)
|
|
1227
1228
|
|
|
1228
1229
|
def _get_cls_through_file(self):
|
|
1229
1230
|
"""
|
|
@@ -1235,12 +1236,14 @@ class SymbolTree(Observer, Observable):
|
|
|
1235
1236
|
Returns:
|
|
1236
1237
|
A class handle.
|
|
1237
1238
|
"""
|
|
1238
|
-
self._update_container()
|
|
1239
1239
|
file_path = os.getcwd()
|
|
1240
1240
|
file_path = os.path.join(file_path, "rewritten_network")
|
|
1241
1241
|
if not os.path.exists(file_path):
|
|
1242
|
-
|
|
1243
|
-
|
|
1242
|
+
try:
|
|
1243
|
+
os.mkdir(file_path, mode=0o700)
|
|
1244
|
+
except FileExistsError:
|
|
1245
|
+
pass
|
|
1246
|
+
file_name = f"{self._opt_cls_name}_{id(self)}.py"
|
|
1244
1247
|
network_file = os.path.join(file_path, file_name)
|
|
1245
1248
|
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
|
1246
1249
|
source = self.get_code()
|
|
@@ -1277,21 +1280,6 @@ class SymbolTree(Observer, Observable):
|
|
|
1277
1280
|
self._modified = True
|
|
1278
1281
|
self.changed(event)
|
|
1279
1282
|
|
|
1280
|
-
def _update_container(self):
|
|
1281
|
-
"""Update instance of node in container."""
|
|
1282
|
-
for node in self.nodes():
|
|
1283
|
-
index = 0
|
|
1284
|
-
if node.get_node_type() == NodeType.CellContainer:
|
|
1285
|
-
for n in node.node_list:
|
|
1286
|
-
if not n.valid:
|
|
1287
|
-
continue
|
|
1288
|
-
if n.get_node_type() == NodeType.Tree:
|
|
1289
|
-
obj = n.symbol_tree.get_network()
|
|
1290
|
-
node.get_instance()[index] = obj
|
|
1291
|
-
else:
|
|
1292
|
-
node.get_instance()[index] = n.get_instance()
|
|
1293
|
-
index += 1
|
|
1294
|
-
|
|
1295
1283
|
def _cal_difference_set(self, input, other):
|
|
1296
1284
|
"""Calculate different set of two sets."""
|
|
1297
1285
|
set1 = set(input)
|
|
@@ -1313,43 +1301,3 @@ class SymbolTree(Observer, Observable):
|
|
|
1313
1301
|
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1314
1302
|
for p in primitives:
|
|
1315
1303
|
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1316
|
-
|
|
1317
|
-
def _create_call_function(self, func, targets, args, kwargs):
|
|
1318
|
-
"""
|
|
1319
|
-
Create a Node object and generate the execution code to insert into the source code.
|
|
1320
|
-
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
1321
|
-
|
|
1322
|
-
Args:
|
|
1323
|
-
func (FunctionType) - The function to be called.
|
|
1324
|
-
targets (list [str]) - indicates the output name. As the output of the node in the source code.
|
|
1325
|
-
args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
|
|
1326
|
-
code. The default value is None, which means there is no parameter input in the cell.
|
|
1327
|
-
kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
|
|
1328
|
-
input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
|
|
1329
|
-
code as the 'kwargs' in the statement expression. The default value is None, which means there is no
|
|
1330
|
-
'kwargs' input.
|
|
1331
|
-
|
|
1332
|
-
Returns:
|
|
1333
|
-
An instance of `Node`.
|
|
1334
|
-
"""
|
|
1335
|
-
if not isinstance(func, types.FunctionType):
|
|
1336
|
-
raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
|
|
1337
|
-
|
|
1338
|
-
_package = func.__globals__['__package__']
|
|
1339
|
-
func_name = ".".join([_package, func.__name__]) if _package else func.__name__
|
|
1340
|
-
|
|
1341
|
-
ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
|
|
1342
|
-
scope_targets = [ScopedValue.create_naming_value(targets[0])]
|
|
1343
|
-
scope_func = ScopedValue.create_naming_value(func_name, "")
|
|
1344
|
-
call_args = list()
|
|
1345
|
-
for arg in args:
|
|
1346
|
-
if isinstance(arg, Node):
|
|
1347
|
-
call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
|
|
1348
|
-
else:
|
|
1349
|
-
call_args.append(ScopedValue.create_variable_value(arg))
|
|
1350
|
-
call_kwargs = {}
|
|
1351
|
-
for k, v in kwargs.items():
|
|
1352
|
-
call_kwargs[k] = ScopedValue.create_variable_value(v)
|
|
1353
|
-
node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
|
|
1354
|
-
call_kwargs)
|
|
1355
|
-
return node
|