mindspore 2.0.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +647 -818
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/rewrite/symbol_tree.py
CHANGED
|
@@ -14,30 +14,31 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
16
|
import stat
|
|
17
|
-
from typing import Optional, Union, Tuple, Any
|
|
17
|
+
from typing import Optional, Union, Tuple, Any, Dict, List
|
|
18
18
|
import os
|
|
19
19
|
import sys
|
|
20
20
|
import ast
|
|
21
|
-
import importlib
|
|
22
|
-
import types
|
|
21
|
+
import importlib.util
|
|
23
22
|
import time
|
|
24
|
-
import astunparse
|
|
25
23
|
|
|
26
24
|
from mindspore.nn import Cell
|
|
27
25
|
from mindspore import log as logger
|
|
28
|
-
from
|
|
29
|
-
from .node import Node, TreeNode
|
|
26
|
+
from .node.node import Node, TreeNode
|
|
30
27
|
from .api.node_type import NodeType
|
|
31
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder,
|
|
28
|
+
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
|
|
32
29
|
from .api.scoped_value import ScopedValue, ValueType
|
|
33
30
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
34
|
-
from .
|
|
31
|
+
from .node.node_topological_manager import TopoManager
|
|
35
32
|
from .namer import TargetNamer, NodeNamer, ClassNamer
|
|
36
33
|
from .common.observer import Observer
|
|
37
34
|
from .common.observable import Observable
|
|
38
35
|
from .common.event import Event
|
|
39
|
-
from .
|
|
36
|
+
from .node.node_manager import NodeManager
|
|
40
37
|
|
|
38
|
+
if sys.version_info >= (3, 9):
|
|
39
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
40
|
+
else:
|
|
41
|
+
import astunparse
|
|
41
42
|
|
|
42
43
|
class Position:
|
|
43
44
|
"""
|
|
@@ -81,6 +82,7 @@ class FieldFinder(AstFinder):
|
|
|
81
82
|
Args:
|
|
82
83
|
scope (ast.AST): An instance of ast node as search scope.
|
|
83
84
|
"""
|
|
85
|
+
|
|
84
86
|
def __init__(self, scope: ast.AST):
|
|
85
87
|
super().__init__(scope)
|
|
86
88
|
self._result = False
|
|
@@ -134,7 +136,7 @@ class IfFixer(ast.NodeTransformer):
|
|
|
134
136
|
self.generic_visit(node)
|
|
135
137
|
|
|
136
138
|
|
|
137
|
-
class SymbolTree(Observer, Observable):
|
|
139
|
+
class SymbolTree(Observer, Observable, NodeManager):
|
|
138
140
|
"""
|
|
139
141
|
A symbol-tree usually corresponding to forward method of a network.
|
|
140
142
|
|
|
@@ -147,227 +149,138 @@ class SymbolTree(Observer, Observable):
|
|
|
147
149
|
"""
|
|
148
150
|
|
|
149
151
|
def __init__(self, origin_network: Cell, module_ast: ast.Module):
|
|
150
|
-
|
|
152
|
+
Observer.__init__(self)
|
|
151
153
|
Observable.__init__(self)
|
|
152
|
-
|
|
154
|
+
self._node_namer = NodeNamer()
|
|
155
|
+
self._node_namer.add_name('obj')
|
|
156
|
+
NodeManager.__init__(self, self._node_namer)
|
|
157
|
+
NodeManager.reg_observer(self, observer=self)
|
|
153
158
|
# init unique-namers
|
|
154
159
|
self._target_namer = TargetNamer()
|
|
155
|
-
|
|
156
|
-
# name or node would use as name of field, so name of origin network handler field should be added into \
|
|
157
|
-
# _node_name_namer.
|
|
158
|
-
self._node_name_namer.add_name(origin_network_key)
|
|
159
|
-
self._topo_mgr = TopoManager()
|
|
160
|
-
self._topo_mgr.reg_observer(self)
|
|
161
|
-
|
|
162
|
-
self._nodes: {str, Node} = {}
|
|
163
|
-
# parameters of forward method
|
|
164
|
-
self._inputs: [Node] = []
|
|
160
|
+
# input arguments of function
|
|
165
161
|
self._ori_cls_name = type(origin_network).__name__
|
|
166
162
|
self._opt_cls_name = ClassNamer.instance().get_name(self._ori_cls_name)
|
|
163
|
+
NodeManager.set_manager_name(self, self._opt_cls_name)
|
|
167
164
|
self._origin_network = origin_network
|
|
168
165
|
self._module_ast: ast.Module = module_ast
|
|
166
|
+
self._import_asts: Optional[ast.Ast] = []
|
|
169
167
|
self._class_ast: Optional[ast.ClassDef] = None
|
|
170
168
|
self._root_ast: Optional[ast.FunctionDef] = None
|
|
171
169
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
172
170
|
self._deleted_field = {}
|
|
173
171
|
self._deleted_node = []
|
|
174
|
-
self.
|
|
172
|
+
self._external_ast = []
|
|
175
173
|
self._father_class_ast = []
|
|
176
|
-
|
|
177
|
-
# head node is always point to the first node(in source code order) of SymbolTree
|
|
178
|
-
self._head = None
|
|
179
|
-
# tail node is always point to the last node(in source code order) of SymbolTree
|
|
180
|
-
self._tail = None
|
|
181
|
-
self._return: Optional[Node] = None
|
|
182
|
-
|
|
183
174
|
self._modified = False
|
|
184
|
-
self._node_visitor = None
|
|
185
|
-
|
|
186
175
|
self._tmp_file_limits = 20
|
|
187
176
|
self._tmp_files = []
|
|
188
177
|
self._saved_file_name = "./network_define.py"
|
|
178
|
+
# used to insert "sys.path.append(xxx)"
|
|
179
|
+
self._net_file_paths = []
|
|
180
|
+
self._tmp_import_strs = []
|
|
181
|
+
self._tmp_unmodified_strees: {type, str} = {}
|
|
182
|
+
self._tmp_replacers = []
|
|
183
|
+
# Record imported modules and names of each files
|
|
184
|
+
# The meanings of `module` and `name` are like code: from `module` import `nameA`, `nameB`
|
|
185
|
+
# Format: {file_path: {module: [name, ...], ...}, ...}
|
|
186
|
+
self._imported_modules: Dict[str, Dict[str, List[str]]] = {}
|
|
189
187
|
|
|
190
188
|
def __del__(self):
|
|
191
189
|
for tmp_file in self._tmp_files:
|
|
192
190
|
tmp_file.close()
|
|
193
191
|
|
|
194
192
|
@staticmethod
|
|
195
|
-
def
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
for node in nodes:
|
|
202
|
-
for arg in node.get_args():
|
|
203
|
-
if consumers.get(arg):
|
|
204
|
-
consumers[arg].append(node)
|
|
205
|
-
else:
|
|
206
|
-
consumers[arg] = [node]
|
|
207
|
-
for _, arg in node.get_kwargs():
|
|
208
|
-
if consumers.get(arg):
|
|
209
|
-
consumers[arg].append(node)
|
|
210
|
-
else:
|
|
211
|
-
consumers[arg] = [node]
|
|
212
|
-
for target in node.get_targets():
|
|
213
|
-
if providers.get(target) is not None:
|
|
214
|
-
raise RuntimeError(f"Target({target}) of node duplicated")
|
|
215
|
-
providers[target] = node
|
|
216
|
-
return consumers, providers
|
|
217
|
-
|
|
218
|
-
@staticmethod
|
|
219
|
-
def _link_nodes_and_find_root(nodes: [Node]) -> Node:
|
|
220
|
-
"""
|
|
221
|
-
Find inputs for all nodes created by Replacement according to their targets and arguments.
|
|
222
|
-
|
|
223
|
-
Find root node of all nodes created by Replacement. One and Only one root should be found.
|
|
224
|
-
|
|
225
|
-
Args:
|
|
226
|
-
nodes (list[Node]): A list of instance of Node created by Replacement.
|
|
227
|
-
|
|
228
|
-
Returns:
|
|
229
|
-
An instance of Node represents root of input nodes.
|
|
230
|
-
"""
|
|
231
|
-
consumers, providers = SymbolTree._find_consumers_and_providers(nodes)
|
|
232
|
-
# find root node
|
|
233
|
-
root = None
|
|
234
|
-
for node in nodes:
|
|
235
|
-
used = 0
|
|
236
|
-
for target in node.get_targets():
|
|
237
|
-
cur_consumers = consumers.get(target)
|
|
238
|
-
if not cur_consumers:
|
|
239
|
-
continue
|
|
240
|
-
for cur_consumer in cur_consumers:
|
|
241
|
-
if id(cur_consumer) != id(node):
|
|
242
|
-
used += 1
|
|
243
|
-
break
|
|
244
|
-
if used == 0:
|
|
245
|
-
if root is not None:
|
|
246
|
-
raise RuntimeError("Replacement should only has one root, found multi-root")
|
|
247
|
-
root = node
|
|
248
|
-
if root is None:
|
|
249
|
-
raise RuntimeError("Replacement should only has one root, found no root")
|
|
250
|
-
# link node's input
|
|
251
|
-
for node in nodes:
|
|
252
|
-
inputs = []
|
|
253
|
-
for _, arg in node.get_normalized_args().items():
|
|
254
|
-
node_input: Node = providers.get(arg)
|
|
255
|
-
if id(node_input) != id(node):
|
|
256
|
-
inputs.append(node_input)
|
|
257
|
-
node.set_inputs(inputs)
|
|
258
|
-
return root
|
|
259
|
-
|
|
260
|
-
@staticmethod
|
|
261
|
-
def _find_all_class_in_symboltree(stree: 'SymbolTree', seen_class: {type, str}, allow_class_name: [], replacers):
|
|
262
|
-
"""Find all non-duplicated class name of SymbolTree recursively."""
|
|
263
|
-
replacer = AstReplacer(stree._class_ast)
|
|
264
|
-
replacers.append(replacer)
|
|
265
|
-
for node in stree.nodes():
|
|
266
|
-
if not isinstance(node, TreeNode):
|
|
193
|
+
def _remove_unused_import(module_ast):
|
|
194
|
+
"""remove unused import in self._module_ast"""
|
|
195
|
+
str_checker = StrChecker(module_ast)
|
|
196
|
+
for i in range(len(module_ast.body) - 1, -1, -1):
|
|
197
|
+
body = module_ast.body[i]
|
|
198
|
+
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
267
199
|
continue
|
|
268
|
-
if
|
|
200
|
+
if isinstance(body, ast.Import):
|
|
269
201
|
continue
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
# all modified ast.ClassDef should export to code
|
|
273
|
-
if sub_stree._modified:
|
|
274
|
-
allow_class_name.append(sub_stree._class_ast.name)
|
|
202
|
+
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
203
|
+
module_ast.body.remove(body)
|
|
275
204
|
continue
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
205
|
+
for alias in body.names:
|
|
206
|
+
name = alias.asname if alias.asname else alias.name
|
|
207
|
+
if not str_checker.check(name):
|
|
208
|
+
if len(body.names) == 1:
|
|
209
|
+
module_ast.body.remove(body)
|
|
210
|
+
i += 1
|
|
211
|
+
else:
|
|
212
|
+
body.names.remove(alias)
|
|
213
|
+
|
|
214
|
+
@staticmethod
|
|
215
|
+
def _remove_duplicated_import(module_ast):
|
|
216
|
+
"""Remove duplicated import of 'net'."""
|
|
217
|
+
imports = set()
|
|
218
|
+
futures = set()
|
|
219
|
+
classes = set()
|
|
220
|
+
|
|
221
|
+
class TransImportNode(ast.NodeTransformer):
|
|
222
|
+
"""Find all import nodes from input ast node."""
|
|
223
|
+
|
|
224
|
+
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
225
|
+
class_str = astunparse.unparse(node)
|
|
226
|
+
if class_str not in classes:
|
|
227
|
+
classes.add(node.name)
|
|
228
|
+
return node
|
|
229
|
+
return
|
|
230
|
+
|
|
231
|
+
def visit_Try(self, node: ast.Try) -> Any:
|
|
232
|
+
if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
|
|
233
|
+
import_str = astunparse.unparse(node)
|
|
234
|
+
if import_str not in imports:
|
|
235
|
+
imports.add(import_str)
|
|
236
|
+
return node
|
|
237
|
+
return
|
|
238
|
+
|
|
239
|
+
def visit_Import(self, node: ast.Import) -> Any:
|
|
240
|
+
import_str = astunparse.unparse(node)
|
|
241
|
+
if import_str not in imports:
|
|
242
|
+
imports.add(import_str)
|
|
243
|
+
return node
|
|
244
|
+
return
|
|
245
|
+
|
|
246
|
+
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
247
|
+
"""
|
|
248
|
+
Once the father class 'A' is defined in the current module, all the next imported class 'A' should
|
|
249
|
+
be removed. e.g.
|
|
250
|
+
def class A():
|
|
251
|
+
...
|
|
252
|
+
from xxx import A, B
|
|
253
|
+
=>
|
|
254
|
+
def class A():
|
|
255
|
+
...
|
|
256
|
+
from xxx import B
|
|
257
|
+
"""
|
|
258
|
+
import_str = astunparse.unparse(node)
|
|
259
|
+
|
|
260
|
+
if import_str not in imports:
|
|
261
|
+
imports.add(import_str)
|
|
262
|
+
# remove "__future__" module
|
|
263
|
+
if node.module == '__future__':
|
|
264
|
+
futures.add(node.module)
|
|
265
|
+
return
|
|
266
|
+
# remove modules which have been defined in the code file
|
|
267
|
+
# it occurs when class A is a father class and other sub-classes import A
|
|
268
|
+
for alias in node.names[:]:
|
|
269
|
+
if alias.name in classes:
|
|
270
|
+
node.names.remove(alias)
|
|
271
|
+
# if the alias(es) in node.names are all removed, this import statement should be removed
|
|
272
|
+
if not node.names:
|
|
273
|
+
return
|
|
274
|
+
return node
|
|
275
|
+
return
|
|
276
|
+
|
|
277
|
+
get_node_handler = TransImportNode()
|
|
278
|
+
get_node_handler.generic_visit(module_ast)
|
|
283
279
|
|
|
284
280
|
def finish_build(self):
|
|
285
281
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
286
282
|
self.add_event(Event.TopologicalChangeEvent)
|
|
287
283
|
|
|
288
|
-
def _create_call_function(self, func, targets, args, kwargs):
|
|
289
|
-
"""
|
|
290
|
-
Create a Node object and generate the execution code to insert into the source code.
|
|
291
|
-
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
292
|
-
|
|
293
|
-
Args:
|
|
294
|
-
func (FunctionType) - The function to be called.
|
|
295
|
-
targets (list [str]) - indicates the output name. As the output of the node in the source code.
|
|
296
|
-
args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
|
|
297
|
-
code. The default value is None, which means there is no parameter input in the cell.
|
|
298
|
-
kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
|
|
299
|
-
input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
|
|
300
|
-
code as the 'kwargs' in the statement expression. The default value is None, which means there is no
|
|
301
|
-
'kwargs' input.
|
|
302
|
-
|
|
303
|
-
Returns:
|
|
304
|
-
An instance of `Node`.
|
|
305
|
-
"""
|
|
306
|
-
if not isinstance(func, types.FunctionType):
|
|
307
|
-
raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
|
|
308
|
-
|
|
309
|
-
_package = func.__globals__['__package__']
|
|
310
|
-
func_name = ".".join([_package, func.__name__]) if _package else func.__name__
|
|
311
|
-
|
|
312
|
-
ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
|
|
313
|
-
scope_targets = [ScopedValue.create_naming_value(targets[0])]
|
|
314
|
-
scope_func = ScopedValue.create_naming_value(func_name, "")
|
|
315
|
-
call_args = list()
|
|
316
|
-
for arg in args:
|
|
317
|
-
if isinstance(arg, Node):
|
|
318
|
-
call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
|
|
319
|
-
else:
|
|
320
|
-
call_args.append(ScopedValue.create_variable_value(arg))
|
|
321
|
-
call_kwargs = {}
|
|
322
|
-
for k, v in kwargs.items():
|
|
323
|
-
call_kwargs[k] = ScopedValue.create_variable_value(v)
|
|
324
|
-
node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
|
|
325
|
-
call_kwargs)
|
|
326
|
-
return node
|
|
327
|
-
|
|
328
|
-
def create_assign_node(self, targets, func_name, args, kwargs):
|
|
329
|
-
"""
|
|
330
|
-
Create a ast.Assign type node.
|
|
331
|
-
|
|
332
|
-
Args:
|
|
333
|
-
targets (list): _description_
|
|
334
|
-
func_name (_type_): _description_
|
|
335
|
-
args (_type_): _description_
|
|
336
|
-
kwargs (_type_): _description_
|
|
337
|
-
|
|
338
|
-
Returns:
|
|
339
|
-
_type_: _description_
|
|
340
|
-
"""
|
|
341
|
-
# create targets
|
|
342
|
-
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
343
|
-
# create call
|
|
344
|
-
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
345
|
-
ast_args = ast_creator_registry.get("Args")(args)
|
|
346
|
-
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
347
|
-
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
348
|
-
# create assign
|
|
349
|
-
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
350
|
-
return ast_node
|
|
351
|
-
|
|
352
|
-
def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
|
|
353
|
-
'''
|
|
354
|
-
Instantiate an instance of node whose type is `CallFunction`.
|
|
355
|
-
|
|
356
|
-
Args:
|
|
357
|
-
node_name (str): Name of node.
|
|
358
|
-
func_name (str): Name of function.
|
|
359
|
-
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
360
|
-
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
361
|
-
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
362
|
-
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
363
|
-
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
364
|
-
class.
|
|
365
|
-
'''
|
|
366
|
-
logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
|
|
367
|
-
node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
|
|
368
|
-
node.set_belong_symbol_tree(self)
|
|
369
|
-
return node
|
|
370
|
-
|
|
371
284
|
def get_ori_cls_name(self) -> str:
|
|
372
285
|
"""
|
|
373
286
|
Get class name of original network.
|
|
@@ -423,6 +336,7 @@ class SymbolTree(Observer, Observable):
|
|
|
423
336
|
corresponding network class.
|
|
424
337
|
"""
|
|
425
338
|
self._root_ast = ast_node
|
|
339
|
+
NodeManager.set_ast_functiondef(self, ast_node)
|
|
426
340
|
|
|
427
341
|
def get_class_ast(self):
|
|
428
342
|
"""
|
|
@@ -461,18 +375,6 @@ class SymbolTree(Observer, Observable):
|
|
|
461
375
|
"""
|
|
462
376
|
self._init_func_ast = ast_node
|
|
463
377
|
|
|
464
|
-
def get_inputs(self):
|
|
465
|
-
return self._inputs
|
|
466
|
-
|
|
467
|
-
def get_head_node(self):
|
|
468
|
-
"""
|
|
469
|
-
Getter of `_head` which represents the beginning node while iterating SymbolTree nodes.
|
|
470
|
-
|
|
471
|
-
Returns:
|
|
472
|
-
An instance of node.
|
|
473
|
-
"""
|
|
474
|
-
return self._head
|
|
475
|
-
|
|
476
378
|
def get_origin_network(self):
|
|
477
379
|
"""
|
|
478
380
|
Getter of `_origin_network`.
|
|
@@ -486,36 +388,53 @@ class SymbolTree(Observer, Observable):
|
|
|
486
388
|
"""Get dict of nodes"""
|
|
487
389
|
return self._nodes
|
|
488
390
|
|
|
489
|
-
def
|
|
391
|
+
def get_node_namer(self):
|
|
392
|
+
"""Get _node_namer"""
|
|
393
|
+
return self._node_namer
|
|
394
|
+
|
|
395
|
+
def is_modified(self):
|
|
490
396
|
"""
|
|
491
|
-
|
|
397
|
+
Check whether symbol tree is modified.
|
|
492
398
|
|
|
493
|
-
|
|
494
|
-
|
|
399
|
+
Symbol tree is considered as modified if operations like insert/replace/erase/set_arg is called after
|
|
400
|
+
the symbol tree is created.
|
|
495
401
|
"""
|
|
496
|
-
|
|
497
|
-
self._node_visitor = NodeVisitor(self)
|
|
498
|
-
it = iter(self._node_visitor)
|
|
402
|
+
return self._modified
|
|
499
403
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
except StopIteration:
|
|
504
|
-
return None
|
|
505
|
-
yield n
|
|
404
|
+
def set_modified_true(self):
|
|
405
|
+
"""
|
|
406
|
+
Set self._modified true.
|
|
506
407
|
|
|
507
|
-
|
|
408
|
+
Self._modified is set true when 'if' exists in the original network.
|
|
409
|
+
In this situation, different original network instance tends to be different.
|
|
410
|
+
Hence, the class name should be updated.
|
|
508
411
|
"""
|
|
509
|
-
|
|
412
|
+
self._modified = True
|
|
510
413
|
|
|
511
|
-
|
|
512
|
-
|
|
414
|
+
def get_import_asts(self):
|
|
415
|
+
"""Get _import_asts"""
|
|
416
|
+
return self._import_asts
|
|
513
417
|
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
418
|
+
def get_external_ast(self):
|
|
419
|
+
"""Get _external_ast"""
|
|
420
|
+
return self._external_ast
|
|
421
|
+
|
|
422
|
+
def get_father_class_ast(self):
|
|
423
|
+
"""Get _father_class_ast"""
|
|
424
|
+
return self._father_class_ast
|
|
517
425
|
|
|
518
|
-
|
|
426
|
+
def get_imported_modules(self, file_path: str):
|
|
427
|
+
"""Get all modules and module_paths in file of `file_path` ."""
|
|
428
|
+
return self._imported_modules.get(file_path, {})
|
|
429
|
+
|
|
430
|
+
def save_imported_modules(self, file_path: str, module: str, names: List[str]):
|
|
431
|
+
"""Save module and names into _imported_modules."""
|
|
432
|
+
imported_modules = self.get_imported_modules(file_path)
|
|
433
|
+
if imported_modules.get(module):
|
|
434
|
+
imported_modules[module].extend(names)
|
|
435
|
+
else:
|
|
436
|
+
imported_modules[module] = names
|
|
437
|
+
self._imported_modules[file_path] = imported_modules
|
|
519
438
|
|
|
520
439
|
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
|
521
440
|
"""
|
|
@@ -553,7 +472,7 @@ class SymbolTree(Observer, Observable):
|
|
|
553
472
|
return []
|
|
554
473
|
if real_node.get_node_type() == NodeType.Output:
|
|
555
474
|
return []
|
|
556
|
-
return
|
|
475
|
+
return TopoManager.get_node_users(real_node)
|
|
557
476
|
|
|
558
477
|
def before(self, node_or_name: Union[Node, str]) -> Position:
|
|
559
478
|
"""
|
|
@@ -606,9 +525,11 @@ class SymbolTree(Observer, Observable):
|
|
|
606
525
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
607
526
|
return Position.create(node.get_belong_symbol_tree(), node, False)
|
|
608
527
|
|
|
609
|
-
def insert_node(self,
|
|
528
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, node_manager: NodeManager = None,
|
|
529
|
+
insert_to_ast: bool = True):
|
|
610
530
|
"""
|
|
611
|
-
Insert a node
|
|
531
|
+
Insert a node before or after base_node.
|
|
532
|
+
|
|
612
533
|
Note:
|
|
613
534
|
Name of node will be unique while inserting node into SymbolTree.
|
|
614
535
|
|
|
@@ -627,52 +548,73 @@ class SymbolTree(Observer, Observable):
|
|
|
627
548
|
Topological relation is updated and inputs of corresponding node is updated.
|
|
628
549
|
|
|
629
550
|
Args:
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
551
|
+
new_node (Node): Node to be inserted.
|
|
552
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
553
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
554
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
555
|
+
NodeManager of symboltree's construct function.
|
|
556
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
634
557
|
|
|
635
558
|
Returns:
|
|
636
559
|
An instance of node which has been inserted into SymbolTree.
|
|
637
560
|
|
|
638
561
|
Raises:
|
|
639
|
-
|
|
562
|
+
ValueError: Node in the SymbolTree is inserted into SymbolTree again.
|
|
640
563
|
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
641
564
|
"""
|
|
642
|
-
if
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
565
|
+
if new_node.get_belong_symbol_tree():
|
|
566
|
+
raise ValueError(f"Node in the SymbolTree cannot be inserted into SymbolTree again: {new_node.get_name()}")
|
|
567
|
+
|
|
568
|
+
# Check if base_node in current SymbolTree
|
|
569
|
+
if base_node is not None:
|
|
570
|
+
stree = base_node.get_belong_symbol_tree()
|
|
571
|
+
if stree is not None and stree is not self:
|
|
572
|
+
raise RuntimeError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
|
|
573
|
+
f"current: {self.get_ori_cls_name()}.")
|
|
574
|
+
|
|
575
|
+
# Check if node is inserted between Input node
|
|
576
|
+
if base_node is not None and base_node.get_node_type() == NodeType.Input:
|
|
652
577
|
valid = True
|
|
653
|
-
if
|
|
578
|
+
if before_node:
|
|
654
579
|
valid = False
|
|
655
|
-
if
|
|
580
|
+
if base_node.get_next() is not None and base_node.get_next().get_node_type() == NodeType.Input:
|
|
656
581
|
valid = False
|
|
657
582
|
if not valid:
|
|
658
|
-
raise RuntimeError("Can not insert a node before or between parameters:",
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
self.
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
if
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
583
|
+
raise RuntimeError("Can not insert a node before or between parameters:", base_node.get_name())
|
|
584
|
+
|
|
585
|
+
# save target name, which is used to provide unique target
|
|
586
|
+
if new_node.get_targets():
|
|
587
|
+
for target in new_node.get_targets():
|
|
588
|
+
self._target_namer.add_name(str(target))
|
|
589
|
+
|
|
590
|
+
self._handle_custom_obj_in_normalized_args(new_node)
|
|
591
|
+
|
|
592
|
+
# Insert node into NodeManager
|
|
593
|
+
if node_manager is None:
|
|
594
|
+
if base_node is None:
|
|
595
|
+
raise RuntimeError("node_manager and base_node cannot both be None when inserting a node.")
|
|
596
|
+
node_manager = base_node.get_node_manager()
|
|
597
|
+
|
|
598
|
+
# set node's _belong_symbol_tree
|
|
599
|
+
new_node.set_belong_symbol_tree(self)
|
|
600
|
+
|
|
601
|
+
if node_manager is self:
|
|
602
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
603
|
+
if insert_to_ast:
|
|
604
|
+
# update init-function-ast and construct-function-ast
|
|
605
|
+
self.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
|
|
606
|
+
else:
|
|
607
|
+
node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
|
|
674
608
|
|
|
675
|
-
|
|
609
|
+
# register code changed event observer, which is used to update _modified flag.
|
|
610
|
+
if new_node.get_node_type() == NodeType.Tree:
|
|
611
|
+
new_node.symbol_tree.reg_observer(self)
|
|
612
|
+
elif isinstance(new_node, NodeManager):
|
|
613
|
+
new_node.reg_observer(self)
|
|
614
|
+
|
|
615
|
+
return new_node
|
|
616
|
+
|
|
617
|
+
def append_node(self, node: Node, node_manager: NodeManager = None, append_to_ast: bool = True) -> Node:
|
|
676
618
|
"""
|
|
677
619
|
Append a node to SymbolTree.
|
|
678
620
|
|
|
@@ -680,13 +622,17 @@ class SymbolTree(Observer, Observable):
|
|
|
680
622
|
node (Node): An instance of node to be appended.
|
|
681
623
|
append_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
682
624
|
True.
|
|
625
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
626
|
+
NodeManager of symboltree's construct function.
|
|
683
627
|
|
|
684
628
|
Returns:
|
|
685
629
|
An instance of node which has been appended to SymbolTree.
|
|
686
630
|
"""
|
|
687
|
-
|
|
631
|
+
if node_manager is None:
|
|
632
|
+
node_manager = self
|
|
633
|
+
return self.insert_node(node, node_manager.get_tail(), False, node_manager, append_to_ast)
|
|
688
634
|
|
|
689
|
-
def append_origin_field(self, node: Node) -> Node:
|
|
635
|
+
def append_origin_field(self, node: Node, node_manager: NodeManager = None) -> Node:
|
|
690
636
|
"""
|
|
691
637
|
Append an original field node to SymbolTree. An original field node represents a node created from existing
|
|
692
638
|
statement in forward method, from existing ast node in ast of forward method, so ast node do not need to update
|
|
@@ -695,18 +641,16 @@ class SymbolTree(Observer, Observable):
|
|
|
695
641
|
|
|
696
642
|
Args:
|
|
697
643
|
node (Node): An instance of node to be appended.
|
|
644
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
645
|
+
NodeManager of symboltree's construct function.
|
|
698
646
|
|
|
699
647
|
Returns:
|
|
700
648
|
An instance of node which has been appended to SymbolTree.
|
|
701
649
|
"""
|
|
702
|
-
self.
|
|
703
|
-
if node.get_node_type() == NodeType.Output:
|
|
704
|
-
self._return = node
|
|
705
|
-
elif node.get_node_type() == NodeType.Input:
|
|
706
|
-
self._inputs.append(node)
|
|
707
|
-
return self.append_node(node, False)
|
|
650
|
+
return self.append_node(node, node_manager, False)
|
|
708
651
|
|
|
709
|
-
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None
|
|
652
|
+
def append_input_node(self, ast_node, param_name: str, default: Optional[ScopedValue] = None,
|
|
653
|
+
node_manager: NodeManager = None):
|
|
710
654
|
"""
|
|
711
655
|
Append an input node to SymbolTree corresponding to parameter of forward method of network class.
|
|
712
656
|
This method is called while building SymbolTree usually.
|
|
@@ -716,13 +660,18 @@ class SymbolTree(Observer, Observable):
|
|
|
716
660
|
param_name (str): A str represents name of parameter of forward method of network class.
|
|
717
661
|
default (ScopedValue, optional): A ScopedValue represents default value of parameter. Default is None which
|
|
718
662
|
means parameter has no default value.
|
|
663
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
664
|
+
NodeManager of symboltree's construct function.
|
|
719
665
|
|
|
720
666
|
Returns:
|
|
721
667
|
An instance of input node which has been appended to SymbolTree.
|
|
722
668
|
"""
|
|
723
669
|
if param_name == "self":
|
|
724
670
|
return
|
|
725
|
-
|
|
671
|
+
# check param_name duplicated
|
|
672
|
+
if node_manager is None:
|
|
673
|
+
node_manager = self
|
|
674
|
+
for input_node in node_manager._inputs:
|
|
726
675
|
targets = input_node.get_targets()
|
|
727
676
|
if len(targets) != 1:
|
|
728
677
|
raise RuntimeError("targets should have 1 elements")
|
|
@@ -735,9 +684,10 @@ class SymbolTree(Observer, Observable):
|
|
|
735
684
|
if exist_param == param_name:
|
|
736
685
|
raise RuntimeError("input duplicated:", param_name)
|
|
737
686
|
input_node = Node.create_input_node(ast_node, param_name, default, name=f"input_{param_name}")
|
|
738
|
-
self.append_origin_field(input_node)
|
|
687
|
+
self.append_origin_field(input_node, node_manager)
|
|
739
688
|
|
|
740
|
-
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST
|
|
689
|
+
def try_append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST,
|
|
690
|
+
node_manager: NodeManager = None) -> Optional[Node]:
|
|
741
691
|
"""
|
|
742
692
|
Try appending a python node to SymbolTree if 'ast_node' is not None and 'ast_node' is not Empty if 'ast_node' is
|
|
743
693
|
a list or a dict.
|
|
@@ -746,6 +696,8 @@ class SymbolTree(Observer, Observable):
|
|
|
746
696
|
Args:
|
|
747
697
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
748
698
|
ast_node (ast.AST): A ast node represents ast node.
|
|
699
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
700
|
+
NodeManager of symboltree's construct function.
|
|
749
701
|
|
|
750
702
|
Returns:
|
|
751
703
|
An instance of python node if a new node has been appended to SymbolTree else None.
|
|
@@ -754,9 +706,9 @@ class SymbolTree(Observer, Observable):
|
|
|
754
706
|
return None
|
|
755
707
|
if isinstance(ast_node, (list, dict)) and not ast_node:
|
|
756
708
|
return None
|
|
757
|
-
return self.append_python_node(ast_scope, ast_node)
|
|
709
|
+
return self.append_python_node(ast_scope, ast_node, node_manager)
|
|
758
710
|
|
|
759
|
-
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST) -> Node:
|
|
711
|
+
def append_python_node(self, ast_scope: ast.AST, ast_node: ast.AST, node_manager: NodeManager = None) -> Node:
|
|
760
712
|
"""
|
|
761
713
|
Append a python node to SymbolTree.
|
|
762
714
|
This method is called while building SymbolTree usually.
|
|
@@ -764,40 +716,50 @@ class SymbolTree(Observer, Observable):
|
|
|
764
716
|
Args:
|
|
765
717
|
ast_scope (ast.AST): A ast node represents ast node of scope of node.
|
|
766
718
|
ast_node (ast.AST): A ast node represents ast node.
|
|
719
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means those asts belong to
|
|
720
|
+
NodeManager of symboltree's construct function.
|
|
767
721
|
|
|
768
722
|
Returns:
|
|
769
723
|
An instance of python node which has been appended to SymbolTree.
|
|
770
724
|
"""
|
|
771
725
|
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
772
|
-
node_name =
|
|
773
|
-
self._update_names_for_unique(ast_node)
|
|
726
|
+
node_name = type(ast_node).__name__
|
|
774
727
|
node = Node.create_python_node(ast_node, node_name)
|
|
775
|
-
|
|
728
|
+
if node_manager is None or node_manager is self:
|
|
729
|
+
NodeManager.append_python_node(self, node)
|
|
730
|
+
else:
|
|
731
|
+
node_manager.append_python_node(node)
|
|
776
732
|
return node
|
|
777
733
|
|
|
778
|
-
def set_output(self, return_value: str,
|
|
734
|
+
def set_output(self, return_value: str, arg_index: int, return_idx: int = 0,
|
|
735
|
+
node_manager: NodeManager = None) -> Node:
|
|
779
736
|
"""
|
|
780
737
|
Update return value of return of forward method of network class.
|
|
781
738
|
|
|
782
739
|
Args:
|
|
783
740
|
return_value (str): A str represents new return value.
|
|
784
|
-
|
|
741
|
+
arg_index (int): A int indicates which value in return to be updated.
|
|
742
|
+
return_idx (int): A int indicates which return node to be updated. Default: 0.
|
|
743
|
+
node_manager (NodeManager): NodeManager those asts belong to. Default: None, means
|
|
744
|
+
symboltree's construct function.
|
|
785
745
|
|
|
786
746
|
Returns:
|
|
787
747
|
An instance of node represents return node after updated.
|
|
788
748
|
"""
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
749
|
+
node_returns = NodeManager.get_returns(self) if node_manager is None else node_manager.get_returns()
|
|
750
|
+
if not node_returns:
|
|
751
|
+
raise RuntimeError("Current node_manager has no output")
|
|
752
|
+
if return_idx >= len(node_returns):
|
|
753
|
+
raise RuntimeError(f"return_idx {return_idx} should be less than return num {len(node_returns)}.")
|
|
754
|
+
node_return = node_returns[return_idx]
|
|
755
|
+
self.set_node_arg(node_return, arg_index, return_value)
|
|
756
|
+
return node_return
|
|
793
757
|
|
|
794
758
|
def erase_node(self, node_or_name: Union[Node, str]) -> Node:
|
|
795
759
|
"""
|
|
796
760
|
Erase a node from SymbolTree.
|
|
797
|
-
Note:
|
|
798
|
-
If node is depended on by other node, RuntimeError will raise.
|
|
799
761
|
|
|
800
|
-
|
|
762
|
+
Topological relation will be updated.
|
|
801
763
|
|
|
802
764
|
Args:
|
|
803
765
|
node_or_name (Union[Node, str]): An instance of node or a str represents name of node.
|
|
@@ -813,71 +775,51 @@ class SymbolTree(Observer, Observable):
|
|
|
813
775
|
node = self._get_real_node(node_or_name)
|
|
814
776
|
if node is None:
|
|
815
777
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
778
|
+
# erase node in NodeManager
|
|
779
|
+
node_manager = node.get_node_manager()
|
|
780
|
+
|
|
781
|
+
logger.debug(f"[earse]stree: {self.get_opt_cls_name()}, "
|
|
782
|
+
f"node_manager: {node_manager.get_manager_name()}, "
|
|
783
|
+
f"code: {astunparse.unparse(node.get_ast()).strip()}, "
|
|
784
|
+
f"node_name:{node.get_name()}")
|
|
785
|
+
|
|
786
|
+
if node_manager is self:
|
|
787
|
+
NodeManager.erase_node(self, node)
|
|
788
|
+
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
789
|
+
if not ret:
|
|
790
|
+
raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
791
|
+
else:
|
|
792
|
+
node_manager.erase_node(node)
|
|
829
793
|
self._deleted_node.append(node.get_name())
|
|
830
794
|
return node
|
|
831
795
|
|
|
832
796
|
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
833
797
|
"""
|
|
834
|
-
Replace an old_node with a
|
|
835
|
-
Note:
|
|
836
|
-
Rewrite will iterate all nodes linked to this root node and insert these nodes into symbol_tree.
|
|
837
|
-
|
|
838
|
-
Inputs of intra sub-tree nodes need to be welly set.
|
|
839
|
-
|
|
840
|
-
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
|
798
|
+
Replace an old_node with a node list.
|
|
841
799
|
|
|
842
800
|
Args:
|
|
843
801
|
old_node (Node): Node to be replaced.
|
|
844
|
-
new_nodes (list[Node]): Node
|
|
802
|
+
new_nodes (list[Node]): Node list to replace in.
|
|
845
803
|
|
|
846
804
|
Returns:
|
|
847
|
-
|
|
805
|
+
Last node in new_nodes list.
|
|
848
806
|
|
|
849
807
|
Raises:
|
|
850
808
|
RuntimeError: If 'old_node' is isolated.
|
|
851
809
|
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
852
810
|
"""
|
|
853
|
-
|
|
854
|
-
if hasattr(old_node, "container"):
|
|
855
|
-
self._replace_container_node(old_node, new_nodes)
|
|
856
|
-
return new_nodes[0]
|
|
857
811
|
real_old_node = self._get_real_node(old_node)
|
|
858
812
|
if real_old_node is None:
|
|
859
813
|
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
860
|
-
#
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
else:
|
|
868
|
-
position = self.after(prev_node)
|
|
869
|
-
# insert node first, because targets of new_node is determined after insert
|
|
870
|
-
new_tree_root = SymbolTree._link_nodes_and_find_root(new_nodes)
|
|
871
|
-
new_node = self._insert_tree(position, new_tree_root)
|
|
872
|
-
# use targets of insert tree to redirect edge
|
|
873
|
-
users = self.get_node_users(old_node)
|
|
874
|
-
if len(new_node.get_targets()) != 1:
|
|
875
|
-
raise RuntimeError("targets of new_node should have 1 elements")
|
|
876
|
-
for user in users:
|
|
877
|
-
self.set_node_arg_by_node(user[0], user[1], new_node)
|
|
878
|
-
# erase old_node after edge is redirected because node can be erased only when node is isolated topologically
|
|
814
|
+
# insert new_nodes into node_manager
|
|
815
|
+
node_manager = real_old_node.get_node_manager()
|
|
816
|
+
# insert new_nodes into NodeManager
|
|
817
|
+
base_node = old_node
|
|
818
|
+
for node in new_nodes:
|
|
819
|
+
self.insert_node(node, base_node, False, node_manager, True)
|
|
820
|
+
base_node = node
|
|
879
821
|
self.erase_node(old_node)
|
|
880
|
-
return
|
|
822
|
+
return new_nodes[-1]
|
|
881
823
|
|
|
882
824
|
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
|
|
883
825
|
"""
|
|
@@ -933,30 +875,234 @@ class SymbolTree(Observer, Observable):
|
|
|
933
875
|
if out_idx >= len(targets):
|
|
934
876
|
raise RuntimeError("out_idx out of range: ", out_idx)
|
|
935
877
|
new_arg = targets[out_idx]
|
|
936
|
-
|
|
878
|
+
real_dst_node.set_arg(new_arg, arg_idx)
|
|
879
|
+
self._topo_mgr.on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
|
|
880
|
+
|
|
881
|
+
def unique_name(self, name: str):
|
|
882
|
+
"""Get a unique name in the symboltree"""
|
|
883
|
+
return self._target_namer.get_name(name)
|
|
884
|
+
|
|
885
|
+
def unique_func_name(self, name: str):
|
|
886
|
+
"""Get a unique function name in the symboltree"""
|
|
887
|
+
if not hasattr(self._origin_network, name):
|
|
888
|
+
return name
|
|
889
|
+
suffix = 1
|
|
890
|
+
while hasattr(self._origin_network, f"{name}_{suffix}"):
|
|
891
|
+
suffix += 1
|
|
892
|
+
return f"{name}_{suffix}"
|
|
893
|
+
|
|
894
|
+
def set_node_target(self, node: Union[Node, str], index: int, target: Union[ScopedValue, str]):
|
|
895
|
+
"""
|
|
896
|
+
Set target of `node` .
|
|
897
|
+
|
|
898
|
+
Args:
|
|
899
|
+
node (Union[Node, str]): Node to be modified. Can be a node or name of node.
|
|
900
|
+
index (int): Indicate which target being modified.
|
|
901
|
+
arg (Union[ScopedValue, str]): New target to been set.
|
|
937
902
|
|
|
938
|
-
|
|
903
|
+
Raises:
|
|
904
|
+
ValueError: If `node` is not belong to current SymbolTree.
|
|
905
|
+
ValueError: If index of `node` 's target is greater than number of targets.
|
|
906
|
+
"""
|
|
907
|
+
|
|
908
|
+
real_node = self._get_real_node(node)
|
|
909
|
+
if real_node is None:
|
|
910
|
+
raise ValueError("Node is not belong to current SymbolTree: ", node)
|
|
911
|
+
if isinstance(target, str):
|
|
912
|
+
target = ScopedValue.create_naming_value(target)
|
|
913
|
+
targets = node.get_targets()
|
|
914
|
+
if index >= len(targets):
|
|
915
|
+
raise ValueError(f"Index of node's target should be less than {len(targets)}, but got {index}")
|
|
916
|
+
old_target = targets[index]
|
|
917
|
+
targets[index] = target
|
|
918
|
+
node.set_targets(targets)
|
|
919
|
+
self._topo_mgr.on_update_target(node, index, old_target, target)
|
|
920
|
+
|
|
921
|
+
def all_nodes(self):
|
|
922
|
+
"""
|
|
923
|
+
Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
|
|
924
|
+
|
|
925
|
+
Returns:
|
|
926
|
+
A list of nodes.
|
|
927
|
+
"""
|
|
928
|
+
nodes = []
|
|
929
|
+
node_managers = [self]
|
|
930
|
+
while node_managers:
|
|
931
|
+
node_manager = node_managers.pop()
|
|
932
|
+
nodes.extend(node_manager.nodes())
|
|
933
|
+
for node in node_manager.nodes():
|
|
934
|
+
if isinstance(node, NodeManager):
|
|
935
|
+
node_managers.append(node)
|
|
936
|
+
for tree_node in self.get_tree_nodes():
|
|
937
|
+
stree = tree_node.symbol_tree
|
|
938
|
+
nodes.extend(stree.all_nodes())
|
|
939
|
+
return nodes
|
|
940
|
+
|
|
941
|
+
def get_node_from_name(self, node_name: str):
|
|
942
|
+
"""
|
|
943
|
+
Get node from all NodeManagers in current symbol tree by `node_name`.
|
|
944
|
+
|
|
945
|
+
Args:
|
|
946
|
+
node_name (str): A str represents name of node as key of query.
|
|
947
|
+
|
|
948
|
+
Returns:
|
|
949
|
+
An instance of Node if found else None.
|
|
950
|
+
"""
|
|
951
|
+
node_managers = [self]
|
|
952
|
+
while node_managers:
|
|
953
|
+
node_manager = node_managers.pop()
|
|
954
|
+
node = node_manager.get_node(node_name)
|
|
955
|
+
if node:
|
|
956
|
+
return node
|
|
957
|
+
for node in node_manager.nodes():
|
|
958
|
+
if isinstance(node, NodeManager):
|
|
959
|
+
node_managers.append(node)
|
|
960
|
+
return None
|
|
961
|
+
|
|
962
|
+
def print_node_tabulate(self, all_nodes: bool = False):
|
|
963
|
+
"""
|
|
964
|
+
Print nodes information and nodes' topological relations.
|
|
965
|
+
|
|
966
|
+
Args:
|
|
967
|
+
all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
|
|
968
|
+
nodes, CellContainer nodes and sub symbol trees.
|
|
969
|
+
"""
|
|
939
970
|
try:
|
|
940
|
-
from tabulate import tabulate
|
|
971
|
+
from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
|
|
941
972
|
except ImportError:
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
973
|
+
logger.warning("print_node_tabulate relies on the library `tabulate`, "
|
|
974
|
+
"which could not be found on this machine. Run `pip "
|
|
975
|
+
"install tabulate` to install the library.")
|
|
976
|
+
return ""
|
|
977
|
+
print(NodeManager.dump(self, self.get_manager_name()))
|
|
978
|
+
if all_nodes:
|
|
979
|
+
node_managers = [self]
|
|
980
|
+
while node_managers:
|
|
981
|
+
node_manager = node_managers.pop()
|
|
982
|
+
for node in node_manager.nodes():
|
|
983
|
+
if isinstance(node, NodeManager):
|
|
984
|
+
print(node.dump(node.get_manager_name()))
|
|
985
|
+
node_managers.append(node)
|
|
986
|
+
for tree_node in self.get_tree_nodes():
|
|
987
|
+
stree = tree_node.symbol_tree
|
|
988
|
+
stree.print_node_tabulate(all_nodes)
|
|
948
989
|
|
|
949
990
|
def dump(self):
|
|
950
991
|
"""Dump graph."""
|
|
951
992
|
dump_st = SymbolTreeDumper(self)
|
|
952
993
|
dump_st.dump()
|
|
953
994
|
|
|
954
|
-
def
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
self.
|
|
995
|
+
def check_body_exist(self, body, code_bodies):
|
|
996
|
+
"""Check whether body already exist in code_bodies"""
|
|
997
|
+
# Check import ast node exist by saving import code string to self._tmp_import_strs
|
|
998
|
+
if isinstance(body, (ast.Import, ast.ImportFrom, ast.Expr)):
|
|
999
|
+
import_str = astunparse.unparse(body)
|
|
1000
|
+
if import_str in self._tmp_import_strs:
|
|
1001
|
+
return True
|
|
1002
|
+
self._tmp_import_strs.append(import_str)
|
|
1003
|
+
return False
|
|
1004
|
+
|
|
1005
|
+
# Check ClassDef ast node exist by using AstClassFinder
|
|
1006
|
+
if isinstance(body, ast.ClassDef):
|
|
1007
|
+
if sys.version_info >= (3, 9):
|
|
1008
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1009
|
+
else:
|
|
1010
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies))
|
|
1011
|
+
results = class_finder.find_all(body.name)
|
|
1012
|
+
return bool(results)
|
|
1013
|
+
|
|
1014
|
+
# Check FunctionDef ast node exist by using AstFunctionFinder
|
|
1015
|
+
if isinstance(body, ast.FunctionDef):
|
|
1016
|
+
if sys.version_info >= (3, 9):
|
|
1017
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1018
|
+
else:
|
|
1019
|
+
function_finder = AstFunctionFinder(ast.Module(body=code_bodies))
|
|
1020
|
+
results = function_finder.find_all(body.name)
|
|
1021
|
+
return bool(results)
|
|
1022
|
+
|
|
1023
|
+
return False
|
|
1024
|
+
|
|
1025
|
+
def update_class_name_of_unmodified_stree(self, stree, code_bodies) -> bool:
|
|
1026
|
+
"""
|
|
1027
|
+
For the unmodified symbol tree, only one definition code remains in the generated code.
|
|
1028
|
+
Everywhere else calling this symbol tree will use the class in this definition code.
|
|
1029
|
+
"""
|
|
1030
|
+
# all modified ast.ClassDef will be exported to code
|
|
1031
|
+
if stree.is_modified():
|
|
1032
|
+
return False
|
|
1033
|
+
# all un-modified ast.ClassDef only keep one instance
|
|
1034
|
+
first_cls_name = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
|
|
1035
|
+
if first_cls_name is None:
|
|
1036
|
+
class_ast = stree.get_class_ast()
|
|
1037
|
+
if class_ast:
|
|
1038
|
+
self._tmp_unmodified_strees[type(stree.get_origin_network())] = class_ast.name
|
|
1039
|
+
return False
|
|
1040
|
+
# Un-modified ast.ClassDef already exist in code_bodies,
|
|
1041
|
+
# replace class name to class name of first un-modified ast.ClassDef.
|
|
1042
|
+
if sys.version_info >= (3, 9):
|
|
1043
|
+
replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1044
|
+
else:
|
|
1045
|
+
replacer = AstReplacer(ast.Module(body=code_bodies))
|
|
1046
|
+
replacer.replace_all(stree.get_class_ast().name, first_cls_name)
|
|
1047
|
+
self._tmp_replacers.append(replacer)
|
|
1048
|
+
return True
|
|
1049
|
+
|
|
1050
|
+
def convert_stree_to_code_bodies(self, stree, code_bodies, insert_pos=0):
|
|
1051
|
+
"""
|
|
1052
|
+
Convert nodes in stree to code_bodies
|
|
1053
|
+
|
|
1054
|
+
1. Add import asts into code_bodies
|
|
1055
|
+
2. Add class, function and other type of asts into code_bodies
|
|
1056
|
+
3. Add father class asts into code_bodies
|
|
1057
|
+
4. Add external function asts into code_bodies
|
|
1058
|
+
5. Add subtrees to code_bodies
|
|
1059
|
+
5.1 Add subtrees in construct to code_bodies
|
|
1060
|
+
5.2 Add subtrees in CellContainers to code_bodies
|
|
1061
|
+
|
|
1062
|
+
"""
|
|
1063
|
+
# Add import asts into code_bodies
|
|
1064
|
+
for body in stree.get_import_asts():
|
|
1065
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1066
|
+
code_bodies.insert(insert_pos, body)
|
|
1067
|
+
insert_pos += 1
|
|
1068
|
+
|
|
1069
|
+
# Add class, function and other type of asts into code_bodies
|
|
1070
|
+
if stree.get_module_ast():
|
|
1071
|
+
for body in stree.get_module_ast().body:
|
|
1072
|
+
if self.check_body_exist(body, code_bodies):
|
|
1073
|
+
continue
|
|
1074
|
+
if isinstance(body, (ast.ClassDef, ast.FunctionDef)):
|
|
1075
|
+
code_bodies.insert(insert_pos, body)
|
|
1076
|
+
else:
|
|
1077
|
+
code_bodies.append(body)
|
|
1078
|
+
|
|
1079
|
+
# Add father class asts into code_bodies
|
|
1080
|
+
for body in reversed(stree.get_father_class_ast()):
|
|
1081
|
+
if self.check_body_exist(body, code_bodies):
|
|
1082
|
+
# remove exist ast in old position, then insert ast to upper position
|
|
1083
|
+
if sys.version_info >= (3, 9):
|
|
1084
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[])).find_all(body.name)[0]
|
|
1085
|
+
else:
|
|
1086
|
+
exist_ast = AstClassFinder(ast.Module(body=code_bodies)).find_all(body.name)[0]
|
|
1087
|
+
code_bodies.remove(exist_ast)
|
|
1088
|
+
code_bodies.insert(insert_pos, body)
|
|
1089
|
+
|
|
1090
|
+
# Add external asts into code_bodies
|
|
1091
|
+
for body in stree.get_external_ast():
|
|
1092
|
+
if not self.check_body_exist(body, code_bodies):
|
|
1093
|
+
code_bodies.insert(insert_pos, body)
|
|
1094
|
+
insert_pos += 1
|
|
1095
|
+
|
|
1096
|
+
# Add subtrees to code_bodies
|
|
1097
|
+
for node in stree.get_tree_nodes():
|
|
1098
|
+
sub_stree = node.symbol_tree
|
|
1099
|
+
# Ignore TreeNode create by function in the class
|
|
1100
|
+
if isinstance(sub_stree.get_module_ast(), ast.FunctionDef):
|
|
1101
|
+
continue
|
|
1102
|
+
# For the unmodified class, update class name to name of first class
|
|
1103
|
+
if self.update_class_name_of_unmodified_stree(sub_stree, code_bodies):
|
|
1104
|
+
continue
|
|
1105
|
+
self.convert_stree_to_code_bodies(node.symbol_tree, code_bodies, insert_pos)
|
|
960
1106
|
|
|
961
1107
|
def get_code(self) -> str:
|
|
962
1108
|
"""
|
|
@@ -965,34 +1111,22 @@ class SymbolTree(Observer, Observable):
|
|
|
965
1111
|
Returns:
|
|
966
1112
|
A str represents source code of modified network.
|
|
967
1113
|
"""
|
|
968
|
-
self.
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
self.
|
|
1114
|
+
self._tmp_import_strs.clear()
|
|
1115
|
+
self._tmp_unmodified_strees.clear()
|
|
1116
|
+
self._tmp_replacers.clear()
|
|
1117
|
+
code_bodies = []
|
|
1118
|
+
self.convert_stree_to_code_bodies(self, code_bodies)
|
|
1119
|
+
if sys.version_info >= (3, 9):
|
|
1120
|
+
gencode_module = ast.Module(body=code_bodies, type_ignores=[])
|
|
1121
|
+
else:
|
|
1122
|
+
gencode_module = ast.Module(body=code_bodies)
|
|
1123
|
+
SymbolTree._remove_unused_import(gencode_module)
|
|
1124
|
+
SymbolTree._remove_duplicated_import(gencode_module)
|
|
973
1125
|
ast.fix_missing_locations(self._module_ast)
|
|
974
|
-
|
|
975
|
-
# Replace duplicated ast.ClassDef reference in main-ClassDef
|
|
976
|
-
seen_class: {type, str} = {}
|
|
977
|
-
allow_class_name = [self._class_ast.name]
|
|
978
|
-
replacers = []
|
|
979
|
-
SymbolTree._find_all_class_in_symboltree(self, seen_class, allow_class_name, replacers)
|
|
980
|
-
# Add all non-ClassDef body to gencode_module
|
|
981
|
-
# Add all ClassDef in allow_class_name to gencode_module
|
|
982
|
-
# Use gencode_module to generate code
|
|
983
|
-
bodies = []
|
|
984
|
-
for body in self._module_ast.body:
|
|
985
|
-
if not isinstance(body, ast.ClassDef):
|
|
986
|
-
bodies.append(body)
|
|
987
|
-
continue
|
|
988
|
-
if body.name in allow_class_name:
|
|
989
|
-
bodies.append(body)
|
|
990
|
-
gencode_module = ast.Module(body=bodies)
|
|
991
|
-
if_fixer = IfFixer()
|
|
992
|
-
if_fixer.fix(gencode_module)
|
|
1126
|
+
IfFixer().fix(gencode_module)
|
|
993
1127
|
code = astunparse.unparse(gencode_module)
|
|
994
|
-
#
|
|
995
|
-
for replacer in
|
|
1128
|
+
# Revert the class name to its original state
|
|
1129
|
+
for replacer in self._tmp_replacers:
|
|
996
1130
|
replacer.undo_all()
|
|
997
1131
|
return code
|
|
998
1132
|
|
|
@@ -1026,305 +1160,71 @@ class SymbolTree(Observer, Observable):
|
|
|
1026
1160
|
f.write(source.encode('utf-8'))
|
|
1027
1161
|
f.flush()
|
|
1028
1162
|
|
|
1029
|
-
def
|
|
1030
|
-
|
|
1031
|
-
if isinstance(node, ast.Call):
|
|
1032
|
-
self.update_scope_for_unique(node.func)
|
|
1033
|
-
return
|
|
1034
|
-
if not isinstance(node, (ast.Attribute, ast.Subscript)):
|
|
1035
|
-
logger.warning(f"Cannot update node {astunparse.unparse(node)} for unique, type of node should "
|
|
1036
|
-
f"be one of (ast.Attribute, ast.Subscript).")
|
|
1037
|
-
return
|
|
1038
|
-
scope = node.value
|
|
1039
|
-
if not isinstance(scope, ast.Name):
|
|
1040
|
-
self.update_scope_for_unique(scope)
|
|
1041
|
-
return
|
|
1042
|
-
scope_name = scope.id
|
|
1043
|
-
scope_name_unique = self._target_namer.get_real_arg(scope_name)
|
|
1044
|
-
scope.id = scope_name_unique
|
|
1045
|
-
|
|
1046
|
-
def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
|
|
1163
|
+
def insert_to_ast_while_insert_node(self, new_node: Node, base_node: Node, before_node: bool,
|
|
1164
|
+
node_manager: NodeManager):
|
|
1047
1165
|
""" insert_to_ast_while_insert_node. """
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
1056
|
-
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
1057
|
-
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
1058
|
-
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
1059
|
-
|
|
1060
|
-
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
1061
|
-
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
1062
|
-
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
1063
|
-
|
|
1064
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1065
|
-
None if position is None else position.node.get_ast(),
|
|
1066
|
-
position.before_node)
|
|
1067
|
-
sub_stree: SymbolTree = node.symbol_tree
|
|
1068
|
-
from .symbol_tree_builder import SymbolTreeBuilder
|
|
1069
|
-
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
1166
|
+
if new_node.get_node_type() == NodeType.Input:
|
|
1167
|
+
# insert a new input
|
|
1168
|
+
self._inputs.append(new_node)
|
|
1169
|
+
ast_construct = self.get_ast_root()
|
|
1170
|
+
arg: str = new_node.get_targets()[0].value
|
|
1171
|
+
ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
|
|
1172
|
+
AstModifier.append_arg_to_function(ast_construct, ast_arg)
|
|
1070
1173
|
else:
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
body = self._module_ast.body[i]
|
|
1086
|
-
if not isinstance(body, (ast.Import, ast.ImportFrom)):
|
|
1087
|
-
continue
|
|
1088
|
-
if isinstance(body, ast.Import):
|
|
1089
|
-
continue
|
|
1090
|
-
if isinstance(body, ast.ImportFrom) and body.module == "cell":
|
|
1091
|
-
self._module_ast.body.remove(body)
|
|
1092
|
-
continue
|
|
1093
|
-
for alias in body.names:
|
|
1094
|
-
name = alias.asname if alias.asname else alias.name
|
|
1095
|
-
if not str_checker.check(name):
|
|
1096
|
-
if len(body.names) == 1:
|
|
1097
|
-
self._module_ast.body.remove(body)
|
|
1098
|
-
i += 1
|
|
1099
|
-
else:
|
|
1100
|
-
body.names.remove(alias)
|
|
1101
|
-
|
|
1102
|
-
def _replace_container_node(self, old_node, new_nodes):
|
|
1103
|
-
cellcontainer = getattr(old_node, "container")
|
|
1104
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1105
|
-
for n in reversed(new_nodes):
|
|
1106
|
-
cellcontainer.insert(index, n)
|
|
1107
|
-
index = cellcontainer.node_list.index(old_node)
|
|
1108
|
-
cellcontainer.erase(old_node)
|
|
1109
|
-
|
|
1110
|
-
def _filter_out_to_delete_field(self, to_delete_field):
|
|
1111
|
-
"""filter out used field from `to_delete_field`"""
|
|
1112
|
-
for func_def in self._class_ast.body:
|
|
1113
|
-
if not isinstance(func_def, ast.FunctionDef):
|
|
1114
|
-
continue
|
|
1115
|
-
if func_def.name != "__init__":
|
|
1116
|
-
to_delete_to_delete_keys = []
|
|
1117
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1118
|
-
for key, _ in self._deleted_field.items():
|
|
1119
|
-
if property_checker.check("self", key):
|
|
1120
|
-
to_delete_to_delete_keys.append(key)
|
|
1121
|
-
property_checker = CheckPropertyIsUsed(func_def)
|
|
1122
|
-
for key in to_delete_to_delete_keys:
|
|
1123
|
-
self._deleted_field.pop(key)
|
|
1174
|
+
# insert a new assign statement
|
|
1175
|
+
ast_assign = new_node.get_ast()
|
|
1176
|
+
if ast_assign is None:
|
|
1177
|
+
func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
|
|
1178
|
+
new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
|
|
1179
|
+
ast_assign = new_node.update_ast_node()
|
|
1180
|
+
if not isinstance(ast_assign, ast.Assign):
|
|
1181
|
+
raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
|
|
1182
|
+
# Save instance into _origin_network.
|
|
1183
|
+
setattr(self._origin_network, new_node.get_name(), new_node.get_instance())
|
|
1184
|
+
# Insert ast to __init__ function
|
|
1185
|
+
if isinstance(new_node, TreeNode):
|
|
1186
|
+
init_code = f"self.{new_node.get_name()} = " \
|
|
1187
|
+
f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
|
|
1124
1188
|
else:
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
self._deleted_field.pop(key)
|
|
1136
|
-
|
|
1137
|
-
def _remove_unused_field(self):
|
|
1138
|
-
"""remove unused field in __init__ function"""
|
|
1139
|
-
multi_targets = []
|
|
1140
|
-
for index, body in enumerate(self._init_func_ast.body):
|
|
1141
|
-
if not isinstance(body, ast.Assign):
|
|
1142
|
-
continue
|
|
1143
|
-
targets = body.targets
|
|
1144
|
-
for target in targets:
|
|
1145
|
-
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
|
|
1146
|
-
and target.value.id == "self":
|
|
1147
|
-
self._deleted_field[target.attr] = index
|
|
1148
|
-
if len(targets) > 1:
|
|
1149
|
-
multi_targets.append(index)
|
|
1150
|
-
self._filter_out_to_delete_field(self._deleted_field)
|
|
1151
|
-
for i in range(len(self._init_func_ast.body) - 1, -1, -1):
|
|
1152
|
-
if i in self._deleted_field.values():
|
|
1153
|
-
if i in multi_targets:
|
|
1154
|
-
raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
|
|
1155
|
-
AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
|
|
1156
|
-
ast.fix_missing_locations(self._init_func_ast)
|
|
1157
|
-
|
|
1158
|
-
def _remove_duplicated_import(self):
|
|
1159
|
-
"""Remove duplicated import of 'net'."""
|
|
1160
|
-
imports = []
|
|
1161
|
-
for body in self._module_ast.body:
|
|
1162
|
-
if isinstance(body, (ast.ImportFrom, ast.Import)):
|
|
1163
|
-
import_str = astunparse.unparse(body)
|
|
1164
|
-
if import_str not in imports:
|
|
1165
|
-
imports.append(import_str)
|
|
1166
|
-
else:
|
|
1167
|
-
self._module_ast.body.remove(body)
|
|
1189
|
+
init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
|
|
1190
|
+
init_ast = ast.parse(init_code).body[0]
|
|
1191
|
+
AstModifier.insert_assign_ast_to_function(self._init_func_ast, init_ast)
|
|
1192
|
+
# Insert ast to construct_function/class_internal_function
|
|
1193
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
1194
|
+
ast_functiondef = node_manager.get_ast_functiondef()
|
|
1195
|
+
if not ast_functiondef:
|
|
1196
|
+
raise RuntimeError(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
|
|
1197
|
+
"when inserting the ast.")
|
|
1198
|
+
AstModifier.insert_assign_ast_to_function(ast_functiondef, ast_assign, ast_base_node, before_node)
|
|
1168
1199
|
|
|
1169
1200
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1170
1201
|
if isinstance(node_or_name, str):
|
|
1171
1202
|
return self.get_node(node_or_name)
|
|
1172
1203
|
return node_or_name
|
|
1173
1204
|
|
|
1174
|
-
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
|
1175
|
-
"""
|
|
1176
|
-
Insert a node-tree into SymbolTree.
|
|
1177
|
-
Note:
|
|
1178
|
-
Inputs of intra sub-tree nodes need to be welly set.
|
|
1179
|
-
|
|
1180
|
-
Inputs of inter sub-tree nodes will be updated by Rewrite automatically.
|
|
1181
|
-
|
|
1182
|
-
Args:
|
|
1183
|
-
position (Position): A Position indicates an insert position point.
|
|
1184
|
-
root (Node): An instance of node as root of node-tree to be inserted in.
|
|
1185
|
-
insert_to_ast (bool): A bool indicates whether to update corresponding ast node at same time, default is
|
|
1186
|
-
True.
|
|
1187
|
-
|
|
1188
|
-
Returns:
|
|
1189
|
-
An instance of node as root node of node-tree which has been inserted into SymbolTree.
|
|
1190
|
-
|
|
1191
|
-
Raises:
|
|
1192
|
-
RuntimeError: If 'position' is not in current SymbolTree.
|
|
1193
|
-
"""
|
|
1194
|
-
|
|
1195
|
-
# if position not in current SymbolTree
|
|
1196
|
-
if position.symbol_tree is not self:
|
|
1197
|
-
raise RuntimeError("Position is not in current SymbolTree: ", position)
|
|
1198
|
-
|
|
1199
|
-
queue: [Node] = [root]
|
|
1200
|
-
todos: [] = []
|
|
1201
|
-
inputs_list: [] = []
|
|
1202
|
-
while queue:
|
|
1203
|
-
cur_node = queue.pop(0)
|
|
1204
|
-
if cur_node in todos:
|
|
1205
|
-
continue
|
|
1206
|
-
todos.append(cur_node)
|
|
1207
|
-
node_inputs = cur_node.get_inputs()
|
|
1208
|
-
inputs_list.append(node_inputs)
|
|
1209
|
-
for node_input in node_inputs:
|
|
1210
|
-
if node_input is not None:
|
|
1211
|
-
queue.append(node_input)
|
|
1212
|
-
todos.reverse()
|
|
1213
|
-
inputs_list.reverse()
|
|
1214
|
-
for index, todo in enumerate(todos):
|
|
1215
|
-
self.insert_node(position, todo, insert_to_ast)
|
|
1216
|
-
position = self.after(todo)
|
|
1217
|
-
# relink input of node
|
|
1218
|
-
original_inputs = inputs_list[index]
|
|
1219
|
-
for arg_idx, original_input in enumerate(original_inputs):
|
|
1220
|
-
if original_input is not None:
|
|
1221
|
-
self.set_node_arg_by_node(todo, arg_idx, original_input)
|
|
1222
|
-
return root
|
|
1223
|
-
|
|
1224
|
-
def _unique_targets(self, node: Node):
|
|
1225
|
-
"""
|
|
1226
|
-
Unique targets of node by _target_namer.
|
|
1227
|
-
|
|
1228
|
-
Args:
|
|
1229
|
-
node (Node): A Node whose targets to be uniqued.
|
|
1230
|
-
"""
|
|
1231
|
-
new_targets: [ScopedValue] = []
|
|
1232
|
-
if node.get_targets() is None:
|
|
1233
|
-
return
|
|
1234
|
-
for target in node.get_targets():
|
|
1235
|
-
if not isinstance(target, ScopedValue):
|
|
1236
|
-
raise TypeError("target should be ScopedValue, got: ", type(target))
|
|
1237
|
-
unique_target = self._target_namer.get_name(target.value)
|
|
1238
|
-
new_targets.append(ScopedValue.create_naming_value(unique_target, target.scope))
|
|
1239
|
-
node.set_targets(new_targets)
|
|
1240
|
-
|
|
1241
|
-
def _update_args_kwargs_for_unique(self, node: Node):
|
|
1242
|
-
"""
|
|
1243
|
-
Update arguments and keyword arguments of node because unique-ing of targets of other nodes.
|
|
1244
|
-
|
|
1245
|
-
Args:
|
|
1246
|
-
node (Node): A Node whose arguments and keyword arguments to be updated.
|
|
1247
|
-
"""
|
|
1248
|
-
result: {str: ScopedValue} = {}
|
|
1249
|
-
if node.get_normalized_args() is None:
|
|
1250
|
-
return
|
|
1251
|
-
for key, arg in node.get_normalized_args().items():
|
|
1252
|
-
if not isinstance(arg, ScopedValue):
|
|
1253
|
-
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
|
1254
|
-
if arg.type == ValueType.NamingValue:
|
|
1255
|
-
# unique name
|
|
1256
|
-
new_arg = ScopedValue(arg.type, arg.scope, self._target_namer.get_real_arg(arg.value))
|
|
1257
|
-
result[key] = new_arg
|
|
1258
|
-
else:
|
|
1259
|
-
result[key] = arg
|
|
1260
|
-
node.set_normalized_args(result)
|
|
1261
|
-
|
|
1262
|
-
def _add_node2nodes(self, node: Node):
|
|
1263
|
-
"""
|
|
1264
|
-
Add `node` to `_nodes` dict.
|
|
1265
|
-
|
|
1266
|
-
Args:
|
|
1267
|
-
node (Node): A Node to be added into `_nodes`.
|
|
1268
|
-
|
|
1269
|
-
Raises:
|
|
1270
|
-
RuntimeError: If name of the node is duplicated.
|
|
1271
|
-
"""
|
|
1272
|
-
node_name = node.get_name()
|
|
1273
|
-
if self._nodes.get(node_name) is not None:
|
|
1274
|
-
raise RuntimeError("generated duplicated node name", node_name, self._nodes.get(node_name),
|
|
1275
|
-
node)
|
|
1276
|
-
self._nodes[node_name] = node
|
|
1277
|
-
|
|
1278
|
-
def _insert_node(self, position: Optional[Position], node: Node):
|
|
1279
|
-
"""
|
|
1280
|
-
Insert a node into SymbolTree.
|
|
1281
|
-
1. Add `node` to `_nodes`.
|
|
1282
|
-
2. Insert `node` to node list(source code order).
|
|
1283
|
-
3. Update topological relation and update inputs of `node`.
|
|
1284
|
-
|
|
1285
|
-
Args:
|
|
1286
|
-
position ([Position, optional]): Indicates node insert position. Position is None when inserting first node
|
|
1287
|
-
of SymbolTree.
|
|
1288
|
-
node (Node): A Node to be inserted into SymbolTree.
|
|
1289
|
-
|
|
1290
|
-
Raises:
|
|
1291
|
-
RuntimeError: Position is None when _nodes of SymbolTree is not Empty. It means position can not be None
|
|
1292
|
-
unless inserting first node.
|
|
1293
|
-
"""
|
|
1294
|
-
if position is None:
|
|
1295
|
-
if self._nodes:
|
|
1296
|
-
raise RuntimeError("self._nodes should be empty")
|
|
1297
|
-
self._head = node
|
|
1298
|
-
else:
|
|
1299
|
-
if position.before_node:
|
|
1300
|
-
position.node.insert_before(node)
|
|
1301
|
-
else:
|
|
1302
|
-
position.node.insert_after(node)
|
|
1303
|
-
self._tail = node
|
|
1304
|
-
self._add_node2nodes(node)
|
|
1305
|
-
self._topo_mgr.on_insert_node(node)
|
|
1306
|
-
node.set_belong_symbol_tree(self)
|
|
1307
|
-
|
|
1308
1205
|
def _handle_custom_obj_in_normalized_args(self, node: Node):
|
|
1309
1206
|
"""
|
|
1310
|
-
Convert CustomObjValue type argument to NamingValue type argument by storing custom object
|
|
1207
|
+
Convert CustomObjValue type argument to NamingValue type argument by storing custom object to obj.
|
|
1311
1208
|
|
|
1312
1209
|
Args:
|
|
1313
1210
|
node (Node): A Node whose arguments and keyword arguments to be handled.
|
|
1314
1211
|
"""
|
|
1315
|
-
|
|
1316
|
-
for
|
|
1212
|
+
normalized_args: {str, ScopedValue} = {}
|
|
1213
|
+
for key, value in node.get_normalized_args().items():
|
|
1317
1214
|
if not isinstance(value, ScopedValue):
|
|
1318
1215
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
1319
1216
|
if value.type == ValueType.CustomObjValue:
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1217
|
+
# Save CustomObjValue into _origin_network(i.e. obj): obj.arg_name = CustomObjValue
|
|
1218
|
+
arg_name = self.unique_name(f"arg_{type(value.value).__name__}")
|
|
1219
|
+
setattr(self._origin_network, arg_name, value.value)
|
|
1220
|
+
# Add new code to __init__(): self.arg_name = obj.arg_name
|
|
1221
|
+
new_ast = ast.parse(f"self.{arg_name} = obj.{arg_name}").body[0]
|
|
1222
|
+
self._init_func_ast.body.append(new_ast)
|
|
1223
|
+
# Modify node's normalized_args: CustomObjValue -> self.arg_name
|
|
1224
|
+
normalized_args[key] = ScopedValue.create_naming_value(arg_name, "self")
|
|
1325
1225
|
else:
|
|
1326
|
-
|
|
1327
|
-
node.set_normalized_args(
|
|
1226
|
+
normalized_args[key] = value
|
|
1227
|
+
node.set_normalized_args(normalized_args)
|
|
1328
1228
|
|
|
1329
1229
|
def _get_cls_through_file(self):
|
|
1330
1230
|
"""
|
|
@@ -1336,12 +1236,14 @@ class SymbolTree(Observer, Observable):
|
|
|
1336
1236
|
Returns:
|
|
1337
1237
|
A class handle.
|
|
1338
1238
|
"""
|
|
1339
|
-
self._update_container()
|
|
1340
1239
|
file_path = os.getcwd()
|
|
1341
1240
|
file_path = os.path.join(file_path, "rewritten_network")
|
|
1342
1241
|
if not os.path.exists(file_path):
|
|
1343
|
-
|
|
1344
|
-
|
|
1242
|
+
try:
|
|
1243
|
+
os.mkdir(file_path, mode=0o700)
|
|
1244
|
+
except FileExistsError:
|
|
1245
|
+
pass
|
|
1246
|
+
file_name = f"{self._opt_cls_name}_{id(self)}.py"
|
|
1345
1247
|
network_file = os.path.join(file_path, file_name)
|
|
1346
1248
|
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
|
1347
1249
|
source = self.get_code()
|
|
@@ -1355,15 +1257,20 @@ class SymbolTree(Observer, Observable):
|
|
|
1355
1257
|
|
|
1356
1258
|
i = 0
|
|
1357
1259
|
while not tmp_module:
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1260
|
+
spec = importlib.util.spec_from_file_location(tmp_module_name, network_file)
|
|
1261
|
+
if spec:
|
|
1262
|
+
tmp_module = importlib.util.module_from_spec(spec)
|
|
1263
|
+
spec.loader.exec_module(tmp_module)
|
|
1264
|
+
else:
|
|
1265
|
+
logger.warning(f"load module {tmp_module_name} failed, retrying.")
|
|
1361
1266
|
if i > 10:
|
|
1362
1267
|
break
|
|
1363
|
-
time.sleep(0.
|
|
1268
|
+
time.sleep(0.5)
|
|
1364
1269
|
i += 1
|
|
1365
1270
|
if not tmp_module:
|
|
1366
1271
|
logger.error(f"load module {tmp_module_name} failed.")
|
|
1272
|
+
# Save new module to sys.modules to support inspect.getsource().
|
|
1273
|
+
sys.modules[tmp_module_name] = tmp_module
|
|
1367
1274
|
network_cls = getattr(tmp_module, self._opt_cls_name)
|
|
1368
1275
|
if network_cls is None:
|
|
1369
1276
|
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
|
@@ -1373,21 +1280,6 @@ class SymbolTree(Observer, Observable):
|
|
|
1373
1280
|
self._modified = True
|
|
1374
1281
|
self.changed(event)
|
|
1375
1282
|
|
|
1376
|
-
def _update_container(self):
|
|
1377
|
-
"""Update instance of node in container."""
|
|
1378
|
-
for node in self.nodes():
|
|
1379
|
-
index = 0
|
|
1380
|
-
if node.get_node_type() == NodeType.CellContainer:
|
|
1381
|
-
for n in node.node_list:
|
|
1382
|
-
if not n.valid:
|
|
1383
|
-
continue
|
|
1384
|
-
if n.get_node_type() == NodeType.Tree:
|
|
1385
|
-
obj = n.symbol_tree.get_network()
|
|
1386
|
-
node.get_instance()[index] = obj
|
|
1387
|
-
else:
|
|
1388
|
-
node.get_instance()[index] = n.get_instance()
|
|
1389
|
-
index += 1
|
|
1390
|
-
|
|
1391
1283
|
def _cal_difference_set(self, input, other):
|
|
1392
1284
|
"""Calculate different set of two sets."""
|
|
1393
1285
|
set1 = set(input)
|
|
@@ -1409,50 +1301,3 @@ class SymbolTree(Observer, Observable):
|
|
|
1409
1301
|
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1410
1302
|
for p in primitives:
|
|
1411
1303
|
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1412
|
-
|
|
1413
|
-
def _update_names_for_unique(self, node: ast.AST):
|
|
1414
|
-
""" Update names of ast nodes for unique. """
|
|
1415
|
-
if isinstance(node, (ast.For, ast.If, ast.While)):
|
|
1416
|
-
self._update_names_for_unique_branchs(node)
|
|
1417
|
-
elif isinstance(node, ast.Assign):
|
|
1418
|
-
self._update_names_for_unique(node.value)
|
|
1419
|
-
for target in node.targets:
|
|
1420
|
-
self._update_names_for_unique(target)
|
|
1421
|
-
elif isinstance(node, ast.Call):
|
|
1422
|
-
if isinstance(node.func, ast.Attribute):
|
|
1423
|
-
self._update_names_for_unique(node.func.value)
|
|
1424
|
-
for arg in node.args:
|
|
1425
|
-
self._update_names_for_unique(arg)
|
|
1426
|
-
for keyword in node.keywords:
|
|
1427
|
-
self._update_names_for_unique(keyword)
|
|
1428
|
-
elif isinstance(node, ast.UnaryOp):
|
|
1429
|
-
self._update_names_for_unique(node.operand)
|
|
1430
|
-
elif isinstance(node, ast.BinOp):
|
|
1431
|
-
self._update_names_for_unique(node.left)
|
|
1432
|
-
self._update_names_for_unique(node.right)
|
|
1433
|
-
elif isinstance(node, (ast.Attribute, ast.Subscript, ast.Return)):
|
|
1434
|
-
self._update_names_for_unique(node.value)
|
|
1435
|
-
elif isinstance(node, (ast.List, ast.Tuple)):
|
|
1436
|
-
for elt in node.elts:
|
|
1437
|
-
self._update_names_for_unique(elt)
|
|
1438
|
-
elif isinstance(node, ast.Compare):
|
|
1439
|
-
for comparator in node.comparators:
|
|
1440
|
-
self._update_names_for_unique(comparator)
|
|
1441
|
-
elif isinstance(node, ast.Name):
|
|
1442
|
-
node.id = self._target_namer.get_real_arg(node.id)
|
|
1443
|
-
|
|
1444
|
-
def _update_names_for_unique_branchs(self, node: Union[ast.For, ast.If, ast.While]):
|
|
1445
|
-
""" Update names of ast nodes for unique with ast.For, ast.If or ast.While """
|
|
1446
|
-
if isinstance(node, ast.For):
|
|
1447
|
-
self._update_names_for_unique(node.target)
|
|
1448
|
-
self._update_names_for_unique(node.iter)
|
|
1449
|
-
for body in node.body:
|
|
1450
|
-
self._update_names_for_unique(body)
|
|
1451
|
-
for body in node.orelse:
|
|
1452
|
-
self._update_names_for_unique(body)
|
|
1453
|
-
elif isinstance(node, (ast.If, ast.While)):
|
|
1454
|
-
self._update_names_for_unique(node.test)
|
|
1455
|
-
for body in node.body:
|
|
1456
|
-
self._update_names_for_unique(body)
|
|
1457
|
-
for body in node.orelse:
|
|
1458
|
-
self._update_names_for_unique(body)
|