mindspore 2.2.14__cp37-cp37m-manylinux1_x86_64.whl → 2.3.0rc1__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the release's page on the package registry for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -4
- mindspore/_akg/akg/composite/build_module.py +155 -11
- mindspore/_akg/akg/config/repository.json +38 -0
- mindspore/_akg/akg/ms/info_version_adapt.py +29 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -1
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +2 -1
- mindspore/_akg/akg/utils/composite_op_helper.py +4 -2
- mindspore/_akg/akg/utils/dump_ascend_meta.py +2 -2
- mindspore/_akg/akg/utils/gen_random.py +14 -8
- mindspore/_akg/akg/utils/op_dsl.py +11 -0
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +5 -5
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +58 -0
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +229 -0
- mindspore/_extends/parse/parser.py +155 -59
- mindspore/_extends/parse/resources.py +40 -7
- mindspore/_extends/parse/standard_method.py +124 -204
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +24 -18
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +3 -1
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_stub_tensor.py +6 -1
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +91 -48
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +5 -4
- mindspore/common/dump.py +5 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +20 -11
- mindspore/common/lazy_inline.py +58 -17
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/parameter.py +19 -4
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +251 -18
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +321 -433
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/management.py +53 -38
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +167 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +72 -38
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +33 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +7 -7
- mindspore/dataset/engine/datasets_vision.py +75 -71
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +5 -5
- mindspore/dataset/vision/transforms.py +2264 -514
- mindspore/dataset/vision/utils.py +40 -9
- mindspore/dataset/vision/validators.py +7 -1
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +40 -16
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +60 -119
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +15 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +34 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/stream.py +337 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +1 -3
- mindspore/include/api/status.h +14 -0
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/ms/base/types.h +1 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/execute.h +1 -3
- mindspore/include/mindapi/base/format.h +125 -23
- mindspore/include/mindapi/base/types.h +7 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +2044 -154
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +2044 -33
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/build_tbe_kernel.py +529 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/compiler.py +56 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/custom.py +1109 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/get_file_path.py +36 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/tbe_topi.py +556 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +0 -2
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6325 -1767
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_add_custom.h +49 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_decoder_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/include/aclnn_prompt_kv_cache.h +59 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_api/lib/libcust_opapi.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +52 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +232 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.cpp +81 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/add_custom.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.cpp +192 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/decoder_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.cpp +274 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/custom_ascendc_ops_impl/dynamic/prompt_kv_cache.py +134 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libcust_opmaster_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/op_tiling/liboptiling.so +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/inc/op_proto.h +39 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_proto/lib/linux/x86_64/libcust_opsproto_rt2.0.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu10.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnvidia_collective.so +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so.1 → libmindspore_ascend.so.2} +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +74 -56
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/multiprocessing/__init__.py +68 -0
- mindspore/nn/cell.py +86 -133
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +79 -90
- mindspore/nn/layer/basic.py +4 -80
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +71 -71
- mindspore/nn/layer/embedding.py +105 -44
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/normalization.py +46 -38
- mindspore/nn/layer/padding.py +26 -39
- mindspore/nn/layer/pooling.py +13 -9
- mindspore/nn/layer/rnn_cells.py +5 -15
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +52 -50
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/loss.py +43 -64
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +3 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +6 -5
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +33 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +7 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +2 -0
- mindspore/numpy/array_ops.py +1 -5
- mindspore/numpy/fft.py +431 -0
- mindspore/numpy/math_ops.py +54 -60
- mindspore/numpy/utils.py +3 -0
- mindspore/ops/__init__.py +5 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +16 -22
- mindspore/ops/_grad_experimental/grad_math_ops.py +68 -283
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +137 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +102 -56
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +74 -49
- mindspore/ops/_vmap/vmap_nn_ops.py +164 -89
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +133 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +248 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +147 -0
- mindspore/ops/auto_generate/gen_extend_func.py +130 -0
- mindspore/ops/auto_generate/gen_ops_def.py +4786 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +8335 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +77 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +118 -17
- mindspore/ops/composite/math_ops.py +9 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +166 -601
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +15 -133
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/extend/__init__.py +46 -0
- mindspore/ops/extend/array_func.py +152 -0
- mindspore/ops/extend/math_func.py +76 -0
- mindspore/ops/{_op_impl/tbe/atomic_addr_clean.py → extend/nn_func.py} +5 -15
- mindspore/ops/function/__init__.py +19 -11
- mindspore/ops/function/array_func.py +251 -1440
- mindspore/ops/function/clip_func.py +12 -13
- mindspore/ops/function/debug_func.py +1 -4
- mindspore/ops/function/fft_func.py +31 -0
- mindspore/ops/function/grad/grad_func.py +24 -17
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +451 -2360
- mindspore/ops/function/nn_func.py +459 -780
- mindspore/ops/function/other_func.py +4 -5
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +24 -80
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +14 -14
- mindspore/ops/functional.py +56 -62
- mindspore/ops/op_info_register.py +22 -19
- mindspore/ops/operations/__init__.py +19 -19
- mindspore/ops/operations/_grad_ops.py +20 -723
- mindspore/ops/operations/_inner_ops.py +178 -286
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +4 -34
- mindspore/ops/operations/array_ops.py +99 -2491
- mindspore/ops/operations/comm_ops.py +38 -46
- mindspore/ops/operations/custom_ops.py +8 -8
- mindspore/ops/operations/debug_ops.py +100 -31
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +3 -38
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/{rewrite/ast_transformers → ops/operations/manually_defined}/__init__.py +11 -4
- mindspore/ops/operations/manually_defined/_inner.py +61 -0
- mindspore/ops/operations/manually_defined/ops_def.py +1391 -0
- mindspore/ops/operations/math_ops.py +703 -4601
- mindspore/ops/operations/nn_ops.py +374 -1748
- mindspore/ops/operations/other_ops.py +50 -42
- mindspore/ops/operations/random_ops.py +3 -52
- mindspore/ops/primitive.py +196 -96
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +248 -0
- mindspore/ops_generate/arg_handler.py +147 -0
- mindspore/ops_generate/gen_aclnn_implement.py +266 -0
- mindspore/ops_generate/gen_ops.py +1062 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +129 -0
- mindspore/ops_generate/gen_pyboost_func.py +932 -0
- mindspore/ops_generate/gen_utils.py +188 -0
- mindspore/ops_generate/op_proto.py +138 -0
- mindspore/ops_generate/pyboost_utils.py +364 -0
- mindspore/ops_generate/template.py +238 -0
- mindspore/parallel/__init__.py +5 -4
- mindspore/parallel/_auto_parallel_context.py +21 -76
- mindspore/parallel/_cell_wrapper.py +16 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +30 -46
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +19 -7
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +1 -1
- mindspore/parallel/_utils.py +131 -6
- mindspore/parallel/algo_parameter_config.py +6 -6
- mindspore/parallel/checkpoint_transform.py +180 -196
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +345 -0
- mindspore/parallel/cluster/process_entity/_utils.py +116 -0
- mindspore/parallel/cluster/run.py +139 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +152 -0
- mindspore/parallel/shard.py +99 -2
- mindspore/profiler/common/util.py +20 -0
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +66 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +77 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +146 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +108 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +80 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +52 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +59 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +25 -277
- mindspore/profiler/parser/ascend_msprof_exporter.py +112 -132
- mindspore/profiler/parser/ascend_msprof_generator.py +68 -285
- mindspore/profiler/parser/ascend_op_generator.py +75 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +293 -135
- mindspore/profiler/parser/base_timeline_generator.py +6 -0
- mindspore/profiler/parser/framework_parser.py +3 -2
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/msadvisor_analyzer.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +5 -0
- mindspore/profiler/profiling.py +296 -166
- mindspore/rewrite/__init__.py +2 -13
- mindspore/rewrite/api/node.py +121 -35
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +635 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +107 -144
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +704 -185
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +6 -14
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +9 -19
- mindspore/scipy/__init__.py +2 -1
- mindspore/scipy/fft.py +133 -0
- mindspore/scipy/linalg.py +140 -55
- mindspore/scipy/ops.py +15 -71
- mindspore/scipy/ops_grad.py +5 -34
- mindspore/scipy/optimize/line_search.py +2 -2
- mindspore/scipy/optimize/minimize.py +1 -1
- mindspore/train/__init__.py +3 -2
- mindspore/train/_utils.py +178 -4
- mindspore/train/amp.py +167 -245
- mindspore/train/callback/_backup_and_restore.py +4 -4
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +39 -13
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +14 -8
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +7 -7
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/data_sink.py +1 -1
- mindspore/train/dataset_helper.py +13 -4
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +75 -6
- mindspore/train/model.py +24 -22
- mindspore/train/serialization.py +256 -132
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/METADATA +2 -2
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/RECORD +515 -1061
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/config/super_bar_config.json +0 -544
- mindspore/gen_ops.py +0 -273
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_creator_register.py +0 -37
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.3.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -15,25 +15,28 @@
|
|
|
15
15
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
16
|
import stat
|
|
17
17
|
from typing import Optional, Union, Tuple, Any, Dict, List
|
|
18
|
+
import types
|
|
18
19
|
import os
|
|
19
20
|
import sys
|
|
20
21
|
import ast
|
|
21
22
|
import importlib.util
|
|
22
23
|
import time
|
|
24
|
+
import inspect
|
|
25
|
+
from textwrap import dedent
|
|
26
|
+
from collections import OrderedDict
|
|
23
27
|
|
|
24
28
|
from mindspore.nn import Cell
|
|
25
29
|
from mindspore import log as logger
|
|
26
|
-
from .node.node import Node, TreeNode
|
|
27
|
-
from .api.node_type import NodeType
|
|
28
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder
|
|
29
|
-
from .api.scoped_value import ScopedValue, ValueType
|
|
30
30
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
31
|
-
from
|
|
32
|
-
from .
|
|
33
|
-
from .
|
|
34
|
-
from
|
|
35
|
-
|
|
36
|
-
from .
|
|
31
|
+
from ..node import Node, TreeNode, ControlFlow, CallFunction, NodeManager
|
|
32
|
+
from ..api.node_type import NodeType
|
|
33
|
+
from ..api.scoped_value import ScopedValue, ValueType
|
|
34
|
+
from ..ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, AstClassFinder, AstFunctionFinder, \
|
|
35
|
+
AstImportFinder
|
|
36
|
+
from ..common.namer import TargetNamer, NodeNamer, ClassNamer
|
|
37
|
+
from ..common.observer import Observer
|
|
38
|
+
from ..common.observable import Observable
|
|
39
|
+
from ..common.event import Event
|
|
37
40
|
|
|
38
41
|
if sys.version_info >= (3, 9):
|
|
39
42
|
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
@@ -115,27 +118,6 @@ class FieldFinder(AstFinder):
|
|
|
115
118
|
return self._result
|
|
116
119
|
|
|
117
120
|
|
|
118
|
-
class IfFixer(ast.NodeTransformer):
|
|
119
|
-
"""
|
|
120
|
-
Fix ast.If if body is empty while orelse is not empty.
|
|
121
|
-
"""
|
|
122
|
-
|
|
123
|
-
def visit_If(self, node: ast.If) -> Any:
|
|
124
|
-
"""Visit a node of type ast.If."""
|
|
125
|
-
if not node.body and node.orelse:
|
|
126
|
-
node.body.append(ast.Pass())
|
|
127
|
-
return super().generic_visit(node)
|
|
128
|
-
|
|
129
|
-
def fix(self, node):
|
|
130
|
-
"""
|
|
131
|
-
Fix ast.If node in `node` if whose body is empty while whose orelse is not empty.
|
|
132
|
-
|
|
133
|
-
Args:
|
|
134
|
-
node (ast.AST): An ast node to be fixed.
|
|
135
|
-
"""
|
|
136
|
-
self.generic_visit(node)
|
|
137
|
-
|
|
138
|
-
|
|
139
121
|
class SymbolTree(Observer, Observable, NodeManager):
|
|
140
122
|
"""
|
|
141
123
|
A symbol-tree usually corresponding to forward method of a network.
|
|
@@ -147,13 +129,16 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
147
129
|
origin_network (Cell): A handler to original network instance.
|
|
148
130
|
module_ast (ast.Module): An instance of ast.AST represents ast node of original network.
|
|
149
131
|
"""
|
|
132
|
+
# whether parse CallFunction node inserted by user.
|
|
133
|
+
_unparse_inserted_function = True
|
|
150
134
|
|
|
151
135
|
def __init__(self, origin_network: Cell, module_ast: ast.Module):
|
|
152
136
|
Observer.__init__(self)
|
|
153
137
|
Observable.__init__(self)
|
|
154
138
|
self._node_namer = NodeNamer()
|
|
155
139
|
self._node_namer.add_name('obj')
|
|
156
|
-
NodeManager.__init__(self
|
|
140
|
+
NodeManager.__init__(self)
|
|
141
|
+
NodeManager.set_manager_node_namer(self, self._node_namer)
|
|
157
142
|
NodeManager.reg_observer(self, observer=self)
|
|
158
143
|
# init unique-namers
|
|
159
144
|
self._target_namer = TargetNamer()
|
|
@@ -169,63 +154,69 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
169
154
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
170
155
|
self._deleted_field = {}
|
|
171
156
|
self._deleted_node = []
|
|
172
|
-
|
|
173
|
-
self.
|
|
157
|
+
# {ast_function: [import_asts]}
|
|
158
|
+
self._external_ast: Dict[ast.FunctionDef, list] = OrderedDict()
|
|
159
|
+
# {ast_class: [import_asts]}
|
|
160
|
+
self._father_class_ast: Dict[ast.ClassDef, list] = OrderedDict()
|
|
174
161
|
self._modified = False
|
|
175
|
-
self._tmp_file_limits = 20
|
|
176
|
-
self._tmp_files = []
|
|
177
162
|
self._saved_file_name = "./network_define.py"
|
|
178
163
|
# used to insert "sys.path.append(xxx)"
|
|
179
164
|
self._net_file_paths = []
|
|
180
165
|
self._tmp_import_strs = []
|
|
181
|
-
self._tmp_unmodified_strees: {type,
|
|
166
|
+
self._tmp_unmodified_strees: {type, List[SymbolTree]} = {}
|
|
182
167
|
self._tmp_replacers = []
|
|
183
|
-
#
|
|
184
|
-
|
|
185
|
-
#
|
|
186
|
-
self.
|
|
187
|
-
|
|
188
|
-
def __del__(self):
|
|
189
|
-
for tmp_file in self._tmp_files:
|
|
190
|
-
tmp_file.close()
|
|
168
|
+
# user custom codes
|
|
169
|
+
self._custom_codes: List[ast.AST] = []
|
|
170
|
+
# local primitive instances initialized during forward method, e.g. abs_inst = P.Abs()
|
|
171
|
+
self._local_prim_inits: List[Node] = []
|
|
191
172
|
|
|
192
173
|
@staticmethod
|
|
193
174
|
def _remove_unused_import(module_ast):
|
|
194
175
|
"""remove unused import in self._module_ast"""
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
if
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
if isinstance(
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
176
|
+
import_nodes: List[Union[ast.Import, ast.ImportFrom]] = []
|
|
177
|
+
|
|
178
|
+
def is_divider(ast_node):
|
|
179
|
+
"""judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
|
|
180
|
+
return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'
|
|
181
|
+
|
|
182
|
+
for ast_node in module_ast.body[:]:
|
|
183
|
+
if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
|
|
184
|
+
import_nodes.append(ast_node)
|
|
185
|
+
if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
|
|
186
|
+
str_checker = StrChecker(ast_node)
|
|
187
|
+
for import_node in import_nodes:
|
|
188
|
+
for alias in import_node.names[:]:
|
|
189
|
+
name = alias.asname if alias.asname else alias.name
|
|
190
|
+
if name == '*':
|
|
191
|
+
continue
|
|
192
|
+
if not str_checker.check(name):
|
|
193
|
+
import_node.names.remove(alias)
|
|
194
|
+
if not import_node.names:
|
|
195
|
+
module_ast.body.remove(import_node)
|
|
196
|
+
if is_divider(ast_node):
|
|
197
|
+
import_nodes.clear()
|
|
213
198
|
|
|
214
199
|
@staticmethod
|
|
215
200
|
def _remove_duplicated_import(module_ast):
|
|
216
201
|
"""Remove duplicated import of 'net'."""
|
|
217
202
|
imports = set()
|
|
218
203
|
futures = set()
|
|
219
|
-
|
|
204
|
+
names = set()
|
|
220
205
|
|
|
221
206
|
class TransImportNode(ast.NodeTransformer):
|
|
222
207
|
"""Find all import nodes from input ast node."""
|
|
223
208
|
|
|
224
209
|
def visit_ClassDef(self, node: ast.ClassDef) -> Any:
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
210
|
+
if node.name not in names:
|
|
211
|
+
names.add(node.name)
|
|
212
|
+
return node
|
|
213
|
+
return None
|
|
214
|
+
|
|
215
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
|
216
|
+
if node.name not in names:
|
|
217
|
+
names.add(node.name)
|
|
228
218
|
return node
|
|
219
|
+
return None
|
|
229
220
|
|
|
230
221
|
def visit_Try(self, node: ast.Try) -> Any:
|
|
231
222
|
if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
|
|
@@ -233,12 +224,14 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
233
224
|
if import_str not in imports:
|
|
234
225
|
imports.add(import_str)
|
|
235
226
|
return node
|
|
227
|
+
return None
|
|
236
228
|
|
|
237
229
|
def visit_Import(self, node: ast.Import) -> Any:
|
|
238
230
|
import_str = astunparse.unparse(node)
|
|
239
231
|
if import_str not in imports:
|
|
240
232
|
imports.add(import_str)
|
|
241
233
|
return node
|
|
234
|
+
return None
|
|
242
235
|
|
|
243
236
|
def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
|
|
244
237
|
"""
|
|
@@ -259,21 +252,225 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
259
252
|
# remove "__future__" module
|
|
260
253
|
if node.module == '__future__':
|
|
261
254
|
futures.add(node.module)
|
|
262
|
-
return
|
|
255
|
+
return None
|
|
263
256
|
# remove modules which have been defined in the code file
|
|
264
257
|
# it occurs when class A is a father class and other sub-classes import A
|
|
265
258
|
for alias in node.names[:]:
|
|
266
|
-
if alias.name in
|
|
259
|
+
if alias.name in names:
|
|
267
260
|
node.names.remove(alias)
|
|
268
261
|
# if the alias(es) in node.names are all removed, this import statement should be removed
|
|
269
262
|
if not node.names:
|
|
270
|
-
return
|
|
263
|
+
return None
|
|
271
264
|
return node
|
|
272
|
-
return
|
|
265
|
+
return None
|
|
273
266
|
|
|
274
267
|
get_node_handler = TransImportNode()
|
|
275
268
|
get_node_handler.generic_visit(module_ast)
|
|
276
269
|
|
|
270
|
+
@staticmethod
|
|
271
|
+
def _remove_arg_annotations(module_ast):
|
|
272
|
+
"""Remove annotations in ast.arg to avoid 'xxx is not defined'."""
|
|
273
|
+
ast_args: List[ast.arg] = AstFinder(module_ast).find_all(ast.arg)
|
|
274
|
+
for ast_arg in ast_args:
|
|
275
|
+
ast_arg.annotation = None
|
|
276
|
+
|
|
277
|
+
@staticmethod
|
|
278
|
+
def _check_import(import_path: str, import_module: str):
|
|
279
|
+
"""
|
|
280
|
+
Check whether import operation is valid when importing module from specific path.
|
|
281
|
+
"""
|
|
282
|
+
if import_path not in sys.path:
|
|
283
|
+
sys.path.append(import_path)
|
|
284
|
+
try:
|
|
285
|
+
importlib.import_module(name=import_module)
|
|
286
|
+
except (ValueError, ImportError) as e:
|
|
287
|
+
logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
|
|
288
|
+
return False
|
|
289
|
+
except Exception as e: # pylint: disable=W0703
|
|
290
|
+
logger.info(f"Test import {import_module} from {import_path} failed: {e}.")
|
|
291
|
+
return False
|
|
292
|
+
return True
|
|
293
|
+
|
|
294
|
+
@staticmethod
|
|
295
|
+
def _process_relative_import(import_node: Union[ast.Import, ast.ImportFrom], file_path: str):
|
|
296
|
+
"""Process relative imports"""
|
|
297
|
+
file_path = os.path.normcase(file_path)
|
|
298
|
+
file_path = os.path.normpath(file_path)
|
|
299
|
+
if isinstance(import_node, ast.ImportFrom):
|
|
300
|
+
# pad the ImportFrom with parent path
|
|
301
|
+
# e.g. from ..C import xxx -> from A.B.C import xxx
|
|
302
|
+
import_module = SymbolTree._get_valid_import_info(import_node, file_path)
|
|
303
|
+
if import_module:
|
|
304
|
+
import_node = ast.ImportFrom(module=import_module, names=import_node.names, level=0)
|
|
305
|
+
return import_node
|
|
306
|
+
|
|
307
|
+
@staticmethod
|
|
308
|
+
def _get_valid_import_info(import_node: ast.ImportFrom, file_path: str):
|
|
309
|
+
"""Get valid import info while import_node.module is at form of relative path"""
|
|
310
|
+
file_path = os.path.dirname(os.path.abspath(file_path))
|
|
311
|
+
# get real path from import_node.level
|
|
312
|
+
# from .(A) import xxx: current path
|
|
313
|
+
# from ..(A) import xxx: last level path
|
|
314
|
+
level = import_node.level
|
|
315
|
+
# from A import xxx: it does not need to pad, directly return the module name
|
|
316
|
+
if level == 0:
|
|
317
|
+
return import_node.module
|
|
318
|
+
if level > 1:
|
|
319
|
+
for _ in range(level - 1):
|
|
320
|
+
file_path = os.path.dirname(file_path)
|
|
321
|
+
file_path_tmp = file_path[:]
|
|
322
|
+
max_level_count = file_path.count(os.path.sep) - 1
|
|
323
|
+
level_count = 0
|
|
324
|
+
# suffix is the module_name, e.g. 'A' in 'from ..(A) import xxx'
|
|
325
|
+
suffix = ''
|
|
326
|
+
if import_node.module:
|
|
327
|
+
suffix = '.' + import_node.module
|
|
328
|
+
while level_count < max_level_count:
|
|
329
|
+
file_path_tmp = os.path.dirname(file_path_tmp)
|
|
330
|
+
if file_path_tmp not in sys.path:
|
|
331
|
+
logger.debug(f"{file_path_tmp} not in sys.path, try upper level.")
|
|
332
|
+
level_count += 1
|
|
333
|
+
continue
|
|
334
|
+
import_module = file_path[len(file_path_tmp) + 1:].replace(os.path.sep, '.') + suffix
|
|
335
|
+
if SymbolTree._check_import(file_path_tmp, import_module):
|
|
336
|
+
# try test code success
|
|
337
|
+
return import_module
|
|
338
|
+
# test import ast failed, try upper level
|
|
339
|
+
level_count += 1
|
|
340
|
+
logger.info(f"Try upper level.")
|
|
341
|
+
# try codes with all level failed
|
|
342
|
+
logger.info(f"Test import code: {astunparse.unparse(import_node).strip()} failed, ignore this import code.")
|
|
343
|
+
return None
|
|
344
|
+
|
|
345
|
+
@staticmethod
|
|
346
|
+
def insert_to_ast_while_insert_input(new_node: Node, node_manager: NodeManager):
|
|
347
|
+
"""update ast when inserting NodeType.Input node"""
|
|
348
|
+
if not isinstance(node_manager, (SymbolTree, CallFunction)):
|
|
349
|
+
raise ValueError(f"Only support insert Input node into a SymbolTree or a node with type of "
|
|
350
|
+
f"CallFunction, but get {type(node_manager)}")
|
|
351
|
+
# insert a new input
|
|
352
|
+
node_manager.get_input_nodes().append(new_node)
|
|
353
|
+
ast_function: ast.FunctionDef = node_manager.get_manager_ast()
|
|
354
|
+
arg: str = new_node.get_targets()[0].value
|
|
355
|
+
ast_arg = ast.arg(arg=arg, annotation=None, type_comment=None)
|
|
356
|
+
AstModifier.append_arg_to_function(ast_function, ast_arg)
|
|
357
|
+
|
|
358
|
+
@staticmethod
|
|
359
|
+
def insert_to_ast_while_insert_cell_primitive(new_node: Node, base_node: Node, before_node: bool,
|
|
360
|
+
node_manager: NodeManager, stree):
|
|
361
|
+
"""update ast when inserting NodeType.CallCell or NodeType.CallPrimitive node"""
|
|
362
|
+
# create a new assign statement
|
|
363
|
+
ast_assign = new_node.get_ast()
|
|
364
|
+
if ast_assign is None:
|
|
365
|
+
func_name = stree.unique_func_name(new_node.get_name())
|
|
366
|
+
new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
|
|
367
|
+
ast_assign = new_node.update_ast_node()
|
|
368
|
+
if not isinstance(ast_assign, ast.Assign):
|
|
369
|
+
raise ValueError(f"Only support insert ast.Assign or Input now, but get {type(ast_assign)}")
|
|
370
|
+
# Save instance into _origin_network.
|
|
371
|
+
setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
|
|
372
|
+
# Insert ast to __init__ function
|
|
373
|
+
if isinstance(new_node, TreeNode):
|
|
374
|
+
init_code = f"{new_node.get_func_name()} = " \
|
|
375
|
+
f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
|
|
376
|
+
else:
|
|
377
|
+
init_code = f"{new_node.get_func_name()} = obj.{new_node.get_name()}"
|
|
378
|
+
init_ast = ast.parse(init_code).body[0]
|
|
379
|
+
AstModifier.insert_ast_to_function(stree.get_init_func_ast(), init_ast)
|
|
380
|
+
# Insert ast to construct_function/class_internal_function
|
|
381
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
382
|
+
ast_node_manager = node_manager.get_manager_ast()
|
|
383
|
+
if not ast_node_manager:
|
|
384
|
+
raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
|
|
385
|
+
"when inserting the ast.")
|
|
386
|
+
AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
|
|
387
|
+
|
|
388
|
+
@staticmethod
|
|
389
|
+
def insert_to_ast_while_insert_function(new_node: CallFunction, base_node: Node, before_node: bool,
|
|
390
|
+
node_manager: NodeManager, stree: 'SymbolTree'):
|
|
391
|
+
"""update ast when inserting NodeType.CallFunction node"""
|
|
392
|
+
func_name = str(new_node.get_func_name())
|
|
393
|
+
# create a new assign statement
|
|
394
|
+
ast_assign = new_node.get_ast()
|
|
395
|
+
if ast_assign is None:
|
|
396
|
+
ast_assign = new_node.update_ast_node()
|
|
397
|
+
# Insert ast to node_manager
|
|
398
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
399
|
+
ast_node_manager = node_manager.get_manager_ast()
|
|
400
|
+
if not ast_node_manager:
|
|
401
|
+
raise RuntimeError(f"ast_node_manager is None in node_manager {node_manager.get_manager_name()} "
|
|
402
|
+
"when inserting the ast.")
|
|
403
|
+
AstModifier.insert_ast_to_ast(ast_node_manager, ast_assign, ast_base_node, before_node)
|
|
404
|
+
# Ignore Python builtin functions
|
|
405
|
+
func_obj = new_node.get_instance()
|
|
406
|
+
if isinstance(func_obj, types.BuiltinFunctionType):
|
|
407
|
+
logger.warning(f"Ignore built in function: {func_name}")
|
|
408
|
+
return
|
|
409
|
+
# get ast.FunctionDef
|
|
410
|
+
source_code = inspect.getsource(func_obj)
|
|
411
|
+
ast_functiondef = ast.parse(dedent(source_code)).body[0]
|
|
412
|
+
if SymbolTree._unparse_inserted_function or not isinstance(ast_functiondef, ast.FunctionDef):
|
|
413
|
+
logger.debug(f"import '{func_name}' to access function object")
|
|
414
|
+
# add import to make sure that the function object can be accessed.
|
|
415
|
+
module = inspect.getmodule(func_obj)
|
|
416
|
+
top_node_manager = node_manager.get_top_manager()
|
|
417
|
+
belonging_ast = None if isinstance(top_node_manager, SymbolTree) else top_node_manager.get_manager_ast()
|
|
418
|
+
stree.add_import(module, func_name, belonging_ast)
|
|
419
|
+
return
|
|
420
|
+
# parse nodes in inserted function.
|
|
421
|
+
new_node.set_manager_ast(ast_functiondef)
|
|
422
|
+
new_node.set_manager_node_namer(stree.get_node_namer())
|
|
423
|
+
stree.get_external_ast()[ast_functiondef] = []
|
|
424
|
+
# import module which function defined in
|
|
425
|
+
func_file_path = inspect.getabsfile(func_obj)
|
|
426
|
+
stree.save_imports_from_file(func_file_path, ast_functiondef)
|
|
427
|
+
# expand ast codes in function
|
|
428
|
+
from ..ast_helpers import AstFlattener
|
|
429
|
+
ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name], stree)
|
|
430
|
+
# parse ast codes into CallFunction Node
|
|
431
|
+
from ..parsers import ParserRegister
|
|
432
|
+
parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
|
433
|
+
parser.process(stree, ast_functiondef, node_manager=new_node)
|
|
434
|
+
|
|
435
|
+
@staticmethod
|
|
436
|
+
def insert_to_ast_while_insert_node(new_node: Node, base_node: Node, before_node: bool):
|
|
437
|
+
""" insert_to_ast_while_insert_node. """
|
|
438
|
+
stree = new_node.get_belong_symbol_tree()
|
|
439
|
+
if not stree:
|
|
440
|
+
raise ValueError(f"When inserting node to ast, the belonging symbol tree of new_node is None.")
|
|
441
|
+
node_manager = new_node.get_node_manager()
|
|
442
|
+
if not isinstance(node_manager, (SymbolTree, CallFunction, ControlFlow)):
|
|
443
|
+
raise ValueError(f"When inserting node to ast, the node_manager of new_node {new_node.get_name()} can "
|
|
444
|
+
f"only be one of [SymbolTree, CallFunction, ControlFlow], but get {type(node_manager)}")
|
|
445
|
+
if new_node.get_node_type() == NodeType.Input:
|
|
446
|
+
SymbolTree.insert_to_ast_while_insert_input(new_node, node_manager)
|
|
447
|
+
elif new_node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
|
448
|
+
SymbolTree.insert_to_ast_while_insert_cell_primitive(new_node, base_node, before_node, node_manager,
|
|
449
|
+
stree)
|
|
450
|
+
elif new_node.get_node_type() == NodeType.CallFunction:
|
|
451
|
+
SymbolTree.insert_to_ast_while_insert_function(new_node, base_node, before_node, node_manager, stree)
|
|
452
|
+
else:
|
|
453
|
+
raise ValueError(f"When insert node '{new_node.get_name()}' into ast, the type of node can only be "
|
|
454
|
+
f"one of [Input, CallCell, CallPrimitive, CallFunction, Tree], but got "
|
|
455
|
+
f"{new_node.get_node_type()}.")
|
|
456
|
+
|
|
457
|
+
@staticmethod
|
|
458
|
+
def get_node_full_name(node: Node) -> str:
|
|
459
|
+
"""Get full name of node"""
|
|
460
|
+
name = node.get_manager_name() if isinstance(node, NodeManager) else node.get_name()
|
|
461
|
+
# traverse node_manager with type of Node
|
|
462
|
+
node_manager = node.get_node_manager()
|
|
463
|
+
while isinstance(node_manager, Node):
|
|
464
|
+
name = f"{node_manager.get_manager_name()}.{name}"
|
|
465
|
+
node_manager = node_manager.get_node_manager()
|
|
466
|
+
# type of node_manager is SymbolTree now
|
|
467
|
+
name = f"{node_manager.get_manager_name()}.{name}"
|
|
468
|
+
return name
|
|
469
|
+
|
|
470
|
+
def local_prim_inits(self) -> List[Node]:
|
|
471
|
+
"""get local primitives constructed during forward method"""
|
|
472
|
+
return self._local_prim_inits
|
|
473
|
+
|
|
277
474
|
def finish_build(self):
|
|
278
475
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
279
476
|
self.add_event(Event.TopologicalChangeEvent)
|
|
@@ -333,7 +530,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
333
530
|
corresponding network class.
|
|
334
531
|
"""
|
|
335
532
|
self._root_ast = ast_node
|
|
336
|
-
NodeManager.
|
|
533
|
+
NodeManager.set_manager_ast(self, ast_node)
|
|
337
534
|
|
|
338
535
|
def get_class_ast(self):
|
|
339
536
|
"""
|
|
@@ -346,7 +543,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
346
543
|
|
|
347
544
|
def set_class_ast(self, ast_node: ast.ClassDef):
|
|
348
545
|
"""
|
|
349
|
-
Setter of `
|
|
546
|
+
Setter of `_class_ast`.
|
|
350
547
|
|
|
351
548
|
Args:
|
|
352
549
|
ast_node (ast.ClassDef): An instance of ast.ClassDef represents ast node of corresponding network class.
|
|
@@ -420,19 +617,6 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
420
617
|
"""Get _father_class_ast"""
|
|
421
618
|
return self._father_class_ast
|
|
422
619
|
|
|
423
|
-
def get_imported_modules(self, file_path: str):
|
|
424
|
-
"""Get all modules and module_paths in file of `file_path` ."""
|
|
425
|
-
return self._imported_modules.get(file_path, {})
|
|
426
|
-
|
|
427
|
-
def save_imported_modules(self, file_path: str, module: str, names: List[str]):
|
|
428
|
-
"""Save module and names into _imported_modules."""
|
|
429
|
-
imported_modules = self.get_imported_modules(file_path)
|
|
430
|
-
if imported_modules.get(module):
|
|
431
|
-
imported_modules[module].extend(names)
|
|
432
|
-
else:
|
|
433
|
-
imported_modules[module] = names
|
|
434
|
-
self._imported_modules[file_path] = imported_modules
|
|
435
|
-
|
|
436
620
|
def get_node_inputs(self, node_or_name: Union[Node, str]) -> [Node]:
|
|
437
621
|
"""
|
|
438
622
|
Getter of inputs in topological relation of current 'node_or_name'.
|
|
@@ -469,7 +653,13 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
469
653
|
return []
|
|
470
654
|
if real_node.get_node_type() == NodeType.Output:
|
|
471
655
|
return []
|
|
472
|
-
|
|
656
|
+
node_users = []
|
|
657
|
+
for target_users in real_node.get_target_users().values():
|
|
658
|
+
if not target_users:
|
|
659
|
+
continue
|
|
660
|
+
if target_users not in node_users:
|
|
661
|
+
node_users.extend(target_users)
|
|
662
|
+
return node_users
|
|
473
663
|
|
|
474
664
|
def before(self, node_or_name: Union[Node, str]) -> Position:
|
|
475
665
|
"""
|
|
@@ -566,8 +756,8 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
566
756
|
if base_node is not None:
|
|
567
757
|
stree = base_node.get_belong_symbol_tree()
|
|
568
758
|
if stree is not None and stree is not self:
|
|
569
|
-
raise
|
|
570
|
-
|
|
759
|
+
raise ValueError(f"Position is not in current SymbolTree, node:{stree.get_ori_cls_name()}, "
|
|
760
|
+
f"current: {self.get_ori_cls_name()}.")
|
|
571
761
|
|
|
572
762
|
# Check if node is inserted between Input node
|
|
573
763
|
if base_node is not None and base_node.get_node_type() == NodeType.Input:
|
|
@@ -599,7 +789,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
599
789
|
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
600
790
|
if insert_to_ast:
|
|
601
791
|
# update init-function-ast and construct-function-ast
|
|
602
|
-
self.insert_to_ast_while_insert_node(new_node, base_node, before_node
|
|
792
|
+
self.insert_to_ast_while_insert_node(new_node, base_node, before_node)
|
|
603
793
|
else:
|
|
604
794
|
node_manager.insert_node(new_node, base_node, before_node, insert_to_ast)
|
|
605
795
|
|
|
@@ -668,7 +858,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
668
858
|
# check param_name duplicated
|
|
669
859
|
if node_manager is None:
|
|
670
860
|
node_manager = self
|
|
671
|
-
for input_node in node_manager.
|
|
861
|
+
for input_node in node_manager.get_input_nodes():
|
|
672
862
|
targets = input_node.get_targets()
|
|
673
863
|
if len(targets) != 1:
|
|
674
864
|
raise RuntimeError("targets should have 1 elements")
|
|
@@ -782,11 +972,15 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
782
972
|
|
|
783
973
|
if node_manager is self:
|
|
784
974
|
NodeManager.erase_node(self, node)
|
|
785
|
-
|
|
975
|
+
if isinstance(node, ControlFlow):
|
|
976
|
+
ret = AstModifier.earse_ast_of_control_flow(self._root_ast.body, node.get_ast(), node.is_orelse)
|
|
977
|
+
else:
|
|
978
|
+
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
786
979
|
if not ret:
|
|
787
980
|
raise RuntimeError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
788
981
|
else:
|
|
789
982
|
node_manager.erase_node(node)
|
|
983
|
+
node.set_belong_symbol_tree(None)
|
|
790
984
|
self._deleted_node.append(node.get_name())
|
|
791
985
|
return node
|
|
792
986
|
|
|
@@ -815,7 +1009,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
815
1009
|
for node in new_nodes:
|
|
816
1010
|
self.insert_node(node, base_node, False, node_manager, True)
|
|
817
1011
|
base_node = node
|
|
818
|
-
|
|
1012
|
+
self.erase_node(old_node)
|
|
819
1013
|
return new_nodes[-1]
|
|
820
1014
|
|
|
821
1015
|
def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
|
|
@@ -836,7 +1030,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
836
1030
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node)
|
|
837
1031
|
|
|
838
1032
|
new_arg, old_arg = node.set_arg(arg, index)
|
|
839
|
-
|
|
1033
|
+
node.get_node_manager().on_update_arg(node, index, old_arg, new_arg)
|
|
840
1034
|
|
|
841
1035
|
def set_node_arg_by_node(self, dst_node: Union[Node, str], arg_idx: int, src_node: Union[Node, str],
|
|
842
1036
|
out_idx: Optional[int] = None):
|
|
@@ -873,7 +1067,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
873
1067
|
raise RuntimeError("out_idx out of range: ", out_idx)
|
|
874
1068
|
new_arg = targets[out_idx]
|
|
875
1069
|
real_dst_node.set_arg(new_arg, arg_idx)
|
|
876
|
-
|
|
1070
|
+
real_dst_node.get_node_manager().on_update_arg_by_node(real_dst_node, arg_idx, real_src_node, out_idx)
|
|
877
1071
|
|
|
878
1072
|
def unique_name(self, name: str):
|
|
879
1073
|
"""Get a unique name in the symboltree"""
|
|
@@ -915,10 +1109,13 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
915
1109
|
node.set_targets(targets)
|
|
916
1110
|
self._topo_mgr.on_update_target(node, index, old_target, target)
|
|
917
1111
|
|
|
918
|
-
def all_nodes(self):
|
|
1112
|
+
def all_nodes(self, subtree_nodes: bool = True):
|
|
919
1113
|
"""
|
|
920
1114
|
Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree.
|
|
921
1115
|
|
|
1116
|
+
Args:
|
|
1117
|
+
subtree_nodes (bool): Whether include nodes in subtree. Default: True.
|
|
1118
|
+
|
|
922
1119
|
Returns:
|
|
923
1120
|
A list of nodes.
|
|
924
1121
|
"""
|
|
@@ -930,9 +1127,10 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
930
1127
|
for node in node_manager.nodes():
|
|
931
1128
|
if isinstance(node, NodeManager):
|
|
932
1129
|
node_managers.append(node)
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
1130
|
+
if subtree_nodes:
|
|
1131
|
+
for tree_node in self.get_tree_nodes():
|
|
1132
|
+
stree = tree_node.symbol_tree
|
|
1133
|
+
nodes.extend(stree.all_nodes())
|
|
936
1134
|
return nodes
|
|
937
1135
|
|
|
938
1136
|
def get_node_from_name(self, node_name: str):
|
|
@@ -956,13 +1154,16 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
956
1154
|
node_managers.append(node)
|
|
957
1155
|
return None
|
|
958
1156
|
|
|
959
|
-
def
|
|
1157
|
+
def get_node_tabulate(self, all_nodes: bool = False) -> str:
|
|
960
1158
|
"""
|
|
961
|
-
|
|
1159
|
+
Get nodes information and nodes' topological relations.
|
|
962
1160
|
|
|
963
1161
|
Args:
|
|
964
1162
|
all_nodes (bool): Print nodes out of construct functions, such as nodes in CallFunction
|
|
965
1163
|
nodes, CellContainer nodes and sub symbol trees.
|
|
1164
|
+
|
|
1165
|
+
Returns:
|
|
1166
|
+
String of nodes' information and topological relations.
|
|
966
1167
|
"""
|
|
967
1168
|
try:
|
|
968
1169
|
from tabulate import tabulate # pylint: disable=unused-import,reportMissingModuleSource
|
|
@@ -971,18 +1172,19 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
971
1172
|
"which could not be found on this machine. Run `pip "
|
|
972
1173
|
"install tabulate` to install the library.")
|
|
973
1174
|
return ""
|
|
974
|
-
|
|
1175
|
+
dump_str = NodeManager.dump(self, self.get_manager_name())
|
|
975
1176
|
if all_nodes:
|
|
976
1177
|
node_managers = [self]
|
|
977
1178
|
while node_managers:
|
|
978
1179
|
node_manager = node_managers.pop()
|
|
979
1180
|
for node in node_manager.nodes():
|
|
980
1181
|
if isinstance(node, NodeManager):
|
|
981
|
-
|
|
1182
|
+
dump_str += node.dump(SymbolTree.get_node_full_name(node))
|
|
982
1183
|
node_managers.append(node)
|
|
983
1184
|
for tree_node in self.get_tree_nodes():
|
|
984
1185
|
stree = tree_node.symbol_tree
|
|
985
|
-
stree.
|
|
1186
|
+
dump_str += stree.get_node_tabulate(all_nodes)
|
|
1187
|
+
return dump_str
|
|
986
1188
|
|
|
987
1189
|
def dump(self):
|
|
988
1190
|
"""Dump graph."""
|
|
@@ -1019,20 +1221,76 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1019
1221
|
|
|
1020
1222
|
return False
|
|
1021
1223
|
|
|
1022
|
-
def
|
|
1224
|
+
def deduplicate_unmodified_stree(self, code_bodies):
|
|
1225
|
+
"""
|
|
1226
|
+
Init function may be different even if stree is not modified manually, when subnets in stree is
|
|
1227
|
+
initialized by different arguments.
|
|
1228
|
+
In this case, we need to wait for code_bodies being fully generated, so that the name of subnets
|
|
1229
|
+
will be updated, then we can deduplicate again according to ast of init function.
|
|
1230
|
+
"""
|
|
1231
|
+
# prepare AstClassFinder and AstReplacer
|
|
1232
|
+
if sys.version_info >= (3, 9):
|
|
1233
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1234
|
+
name_replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1235
|
+
else:
|
|
1236
|
+
class_finder = AstClassFinder(ast.Module(body=code_bodies))
|
|
1237
|
+
name_replacer = AstReplacer(ast.Module(body=code_bodies))
|
|
1238
|
+
# deduplicate all unmodified strees in self._tmp_unmodified_strees
|
|
1239
|
+
deduplicated = False
|
|
1240
|
+
for _, unmodified_strees in self._tmp_unmodified_strees.items():
|
|
1241
|
+
if len(unmodified_strees) <= 1:
|
|
1242
|
+
continue
|
|
1243
|
+
init_func_codes = [astunparse.unparse(stree.get_init_func_ast()) for stree in unmodified_strees]
|
|
1244
|
+
# If the index of an element is not its own, it means that it is a duplicate element
|
|
1245
|
+
to_be_erase = []
|
|
1246
|
+
for idx, code in enumerate(init_func_codes):
|
|
1247
|
+
first_idx = init_func_codes.index(code)
|
|
1248
|
+
if first_idx != idx:
|
|
1249
|
+
first_stree_cls_name = unmodified_strees[first_idx].get_opt_cls_name()
|
|
1250
|
+
duplicated_stree_cls_name = unmodified_strees[idx].get_opt_cls_name()
|
|
1251
|
+
logger.debug(f"replace stree:{duplicated_stree_cls_name} to {first_stree_cls_name}.")
|
|
1252
|
+
# delete duplicated class from code_bodies
|
|
1253
|
+
results = class_finder.find_all(duplicated_stree_cls_name)
|
|
1254
|
+
for ast_cls in results:
|
|
1255
|
+
code_bodies.remove(ast_cls)
|
|
1256
|
+
# replace name of duplicated class in code_bodies to first_stree_cls_name
|
|
1257
|
+
name_replacer.replace_all(duplicated_stree_cls_name, first_stree_cls_name)
|
|
1258
|
+
# record deduplicated stree
|
|
1259
|
+
to_be_erase.append(idx)
|
|
1260
|
+
deduplicated = True
|
|
1261
|
+
# remove class in self._tmp_unmodified_strees
|
|
1262
|
+
for idx in reversed(to_be_erase):
|
|
1263
|
+
unmodified_strees.pop(idx)
|
|
1264
|
+
|
|
1265
|
+
# the name of subnets is updated, so we need to deduplicate again.
|
|
1266
|
+
if deduplicated:
|
|
1267
|
+
self._tmp_replacers.append(name_replacer)
|
|
1268
|
+
self.deduplicate_unmodified_stree(code_bodies)
|
|
1269
|
+
|
|
1270
|
+
def update_unmodified_stree(self, stree, code_bodies) -> bool:
|
|
1023
1271
|
"""
|
|
1024
1272
|
For the unmodified symbol tree, only one definition code remains in the generated code.
|
|
1025
1273
|
Everywhere else calling this symbol tree will use the class in this definition code.
|
|
1026
1274
|
"""
|
|
1027
1275
|
# all modified ast.ClassDef will be exported to code
|
|
1028
1276
|
if stree.is_modified():
|
|
1277
|
+
logger.debug(f"stree:{stree.get_opt_cls_name()} is modified.")
|
|
1029
1278
|
return False
|
|
1030
1279
|
# all un-modified ast.ClassDef only keep one instance
|
|
1031
|
-
|
|
1032
|
-
if
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1280
|
+
unmodified_strees = self._tmp_unmodified_strees.get(type(stree.get_origin_network()))
|
|
1281
|
+
if not unmodified_strees:
|
|
1282
|
+
self._tmp_unmodified_strees[type(stree.get_origin_network())] = [stree]
|
|
1283
|
+
logger.debug(f"stree:{stree.get_opt_cls_name()} is the first stree.")
|
|
1284
|
+
return False
|
|
1285
|
+
# Init function may be different even if stree is not modified, when subnets in stree is
|
|
1286
|
+
# initialized by different arguments.
|
|
1287
|
+
first_stree = unmodified_strees[0]
|
|
1288
|
+
first_stree_cls_name = first_stree.get_opt_cls_name()
|
|
1289
|
+
if astunparse.unparse(stree.get_init_func_ast()) != astunparse.unparse(first_stree.get_init_func_ast()):
|
|
1290
|
+
# init ast may be updated after inserting subtrees of stree, so we need to save unmodified strees
|
|
1291
|
+
# and deduplicate later
|
|
1292
|
+
self._tmp_unmodified_strees[type(stree.get_origin_network())].append(stree)
|
|
1293
|
+
logger.debug(f"init func different, stree:{stree.get_opt_cls_name()}, first_stree:{first_stree_cls_name}.")
|
|
1036
1294
|
return False
|
|
1037
1295
|
# Un-modified ast.ClassDef already exist in code_bodies,
|
|
1038
1296
|
# replace class name to class name of first un-modified ast.ClassDef.
|
|
@@ -1040,66 +1298,105 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1040
1298
|
replacer = AstReplacer(ast.Module(body=code_bodies, type_ignores=[]))
|
|
1041
1299
|
else:
|
|
1042
1300
|
replacer = AstReplacer(ast.Module(body=code_bodies))
|
|
1043
|
-
|
|
1301
|
+
logger.debug(f"replace stree:{stree.get_opt_cls_name()} to {first_stree_cls_name}.")
|
|
1302
|
+
replacer.replace_all(stree.get_class_ast().name, first_stree_cls_name)
|
|
1044
1303
|
self._tmp_replacers.append(replacer)
|
|
1045
1304
|
return True
|
|
1046
1305
|
|
|
1047
|
-
def
|
|
1306
|
+
def init_code_bodies(self, code_bodies: list) -> int:
    """
    Seed `code_bodies` with the header of the generated module.

    Appends, in order: the basic imports (`sys`, `mindspore`, `nn`, `Cell` and `functional as F`),
    a '#' divider, every user custom code ast returned by `self.get_custom_codes()`, and a second
    '#' divider.

    Args:
        code_bodies (list): List of ast statements, extended in place.

    Returns:
        int: Length of `code_bodies` after seeding, used by the caller as the position where the
        following codes start.
    """
    def make_divider():
        # An ast.Expr holding the name '#' marks the boundary between generated sections.
        return ast.Expr(ast.Name("#", ast.Load()))

    basic_imports = [
        ast.Import([ast.alias(name='sys', asname=None)]),
        ast.Import([ast.alias(name='mindspore', asname=None)]),
        ast.ImportFrom(module='mindspore', names=[ast.alias(name='nn', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.nn', names=[ast.alias(name='Cell', asname=None)], level=0),
        ast.ImportFrom(module='mindspore.ops', names=[ast.alias(name='functional', asname='F')], level=0),
    ]
    code_bodies.extend(basic_imports)
    code_bodies.append(make_divider())
    # User custom codes go right after the basic imports
    code_bodies.extend(self.get_custom_codes())
    code_bodies.append(make_divider())
    return len(code_bodies)
|
|
1322
|
+
|
|
1323
|
+
def convert_stree_to_code_bodies(self, stree: 'SymbolTree', code_bodies: list, dividing_pos=0) -> int:
    """
    Convert nodes in stree to code_bodies.

    Insertion starts at `dividing_pos`, in this order:

    - external function asts of `stree`, each preceded by its own imports and followed by a '#' divider;
    - father class asts of `stree`, each preceded by its own imports and followed by a '#' divider;
    - import asts of `stree`;
    - class asts of `stree` (the body of its module ast), followed by a '#' divider;
    - subtrees of `stree`, recursively. Each is inserted at the updated dividing position, so it
      lands above the imports and class asts of `stree`. A subtree is skipped when
      `update_unmodified_stree` reports it is already represented.

    Asts for which `check_body_exist` reports an existing copy in `code_bodies` are not inserted again.

    Args:
        stree (SymbolTree): Symbol tree to convert.
        code_bodies (list): Ast statements of the generated module, modified in place.
        dividing_pos (int): Index at which insertion starts. Default: 0.

    Returns:
        int: The new dividing position, i.e. the index right after the external functions and father
        classes inserted by this call (including those of subtrees).
    """
    insert_pos = dividing_pos

    def insert_ast(ast_node):
        """Insert `ast_node` at the current insert position and advance the position."""
        nonlocal insert_pos
        code_bodies.insert(insert_pos, ast_node)
        insert_pos += 1

    def insert_divider():
        """Insert a '#' divider between generated sections."""
        insert_ast(ast.Expr(ast.Name("#", ast.Load())))

    def insert_with_imports(def_ast, import_asts):
        """Insert `def_ast` preceded by its missing imports and followed by a divider, unless present."""
        if self.check_body_exist(def_ast, code_bodies):
            return
        self._tmp_import_strs.clear()
        for ast_import in import_asts:
            if not self.check_body_exist(ast_import, code_bodies):
                insert_ast(ast_import)
        insert_ast(def_ast)
        insert_divider()

    # Add external function asts into code_bodies.
    # dict views only support reversed() since Python 3.8; materialize the items so this also works
    # on Python 3.7.
    for ast_func, import_asts in reversed(list(stree.get_external_ast().items())):
        insert_with_imports(ast_func, import_asts)

    # Add father class asts into code_bodies
    for ast_class, import_asts in stree.get_father_class_ast().items():
        insert_with_imports(ast_class, import_asts)

    # external functions and father class are above the dividing_pos to support deduplication.
    dividing_pos = insert_pos

    # Add import asts of symbol tree into code_bodies
    self._tmp_import_strs.clear()
    for body in stree.get_import_asts():
        if not self.check_body_exist(body, code_bodies):
            insert_ast(body)

    # Add class asts of symbol tree into code_bodies
    module_ast = stree.get_module_ast()
    if module_ast:
        for body in module_ast.body:
            if not self.check_body_exist(body, code_bodies):
                insert_ast(body)
    insert_divider()

    # Add subtrees to code_bodies
    for node in stree.get_tree_nodes():
        sub_stree = node.symbol_tree
        # For the unmodified class, update class name to name of first class
        if self.update_unmodified_stree(sub_stree, code_bodies):
            continue
        dividing_pos = self.convert_stree_to_code_bodies(sub_stree, code_bodies, dividing_pos)

    # return new dividing position
    return dividing_pos
|
|
1103
1400
|
|
|
1104
1401
|
def get_code(self) -> str:
|
|
1105
1402
|
"""
|
|
@@ -1112,15 +1409,18 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1112
1409
|
self._tmp_unmodified_strees.clear()
|
|
1113
1410
|
self._tmp_replacers.clear()
|
|
1114
1411
|
code_bodies = []
|
|
1115
|
-
self.
|
|
1412
|
+
begin_pos = self.init_code_bodies(code_bodies)
|
|
1413
|
+
self.convert_stree_to_code_bodies(self, code_bodies, begin_pos)
|
|
1414
|
+
self.deduplicate_unmodified_stree(code_bodies)
|
|
1116
1415
|
if sys.version_info >= (3, 9):
|
|
1117
1416
|
gencode_module = ast.Module(body=code_bodies, type_ignores=[])
|
|
1118
1417
|
else:
|
|
1119
1418
|
gencode_module = ast.Module(body=code_bodies)
|
|
1120
1419
|
SymbolTree._remove_unused_import(gencode_module)
|
|
1420
|
+
self._process_duplicate_name_modules(gencode_module)
|
|
1121
1421
|
SymbolTree._remove_duplicated_import(gencode_module)
|
|
1422
|
+
SymbolTree._remove_arg_annotations(gencode_module)
|
|
1122
1423
|
ast.fix_missing_locations(self._module_ast)
|
|
1123
|
-
IfFixer().fix(gencode_module)
|
|
1124
1424
|
code = astunparse.unparse(gencode_module)
|
|
1125
1425
|
# Revert the class name to its original state
|
|
1126
1426
|
for replacer in self._tmp_replacers:
|
|
@@ -1137,6 +1437,9 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1137
1437
|
cls = self._get_cls_through_file()
|
|
1138
1438
|
new_net = cls(self._origin_network)
|
|
1139
1439
|
self._merge_origin_property(new_net)
|
|
1440
|
+
# update parameters' names to fix duplicated names bug
|
|
1441
|
+
# which occurs after inserting cell to celllist/sequentialcell
|
|
1442
|
+
new_net.update_parameters_name()
|
|
1140
1443
|
return new_net
|
|
1141
1444
|
|
|
1142
1445
|
def set_saved_file_name(self, file_name: str):
|
|
@@ -1157,42 +1460,189 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1157
1460
|
f.write(source.encode('utf-8'))
|
|
1158
1461
|
f.flush()
|
|
1159
1462
|
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
"""
|
|
1163
|
-
if
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1463
|
+
|
|
1464
|
+
def flatten_nodes(self, node, erase_another_branch: bool = False, erase_nodes_after_return: bool = False):
    """
    Flatten nodes in ControlFlow node.

    Each node inside `node` is erased from it, then re-inserted into the node manager that contains
    `node`, chained one after another. The chain starts from `node.orelse_node` when it is set, otherwise
    from `node.body_node`. The `False` position flag passed to `insert_node` / `insert_ast_to_bodies`
    presumably means "after the base node" (TODO confirm). The asts are moved the same way.
    Finally `node` itself is erased.

    Args:
        node (ControlFlow): Control flow node whose nodes are flattened.
        erase_another_branch (bool): If True, also erase the other branch. When `node.is_orelse`, that is
            `node.body_node`; otherwise it is `node.orelse_node`, if any. Default: False.
        erase_nodes_after_return (bool): If True, erase every node of the upper node manager that comes
            after the first Output node, logging a warning for each. Default: False.

    Raises:
        ValueError: If `node` is not a ControlFlow, or its node manager is not a SymbolTree,
            CallFunction or ControlFlow.
    """
    if not isinstance(node, ControlFlow):
        raise ValueError(f"For flatten_nodes, the type of node can only be ControlFlow, but got {type(node)}.")
    upper_node_manager = node.get_node_manager()
    if isinstance(upper_node_manager, (SymbolTree, CallFunction)):
        ast_bodies = upper_node_manager.get_manager_ast().body
    elif isinstance(upper_node_manager, ControlFlow):
        ast_bodies = upper_node_manager.get_manager_ast()
    else:
        raise ValueError("For flatten_nodes, the node can only be contained in [SymbolTree, CallFunction, "
                         f"ControlFlow], but the node is in {type(upper_node_manager)}.")
    base_node = node.orelse_node if node.orelse_node else node.body_node
    # Iterate over a copy: each n is erased from `node` inside the loop.
    for n in node.nodes()[:]:
        self.erase_node(n)
        self.insert_node(n, base_node, False, upper_node_manager, False)
        AstModifier.insert_ast_to_bodies(ast_bodies, n.get_ast(), base_node.get_ast(), False)
        base_node = n
    self.erase_node(node)
    # remove another branch
    if erase_another_branch:
        if node.is_orelse:
            self.erase_node(node.body_node)
        elif node.orelse_node is not None:
            self.erase_node(node.orelse_node)
    # remove nodes after return node
    if erase_nodes_after_return:
        has_return = False
        # Snapshot the nodes: they are erased from upper_node_manager while we walk them.
        for n in list(upper_node_manager.nodes()):
            if has_return:
                logger.warning(f"Node {n.get_name()} which is behind the flatten return node is "
                               "automatically erased.")
                self.erase_node(n)
            elif n.get_node_type() == NodeType.Output:
                has_return = True
|
|
1499
|
+
|
|
1500
|
+
def eval_ast_result(self, ast_node: ast.AST) -> (bool, bool):
    """
    Eval ast_node and get result, only used in control flow node.

    Constants are read directly. Any other expression is evaluated with the module that defines the
    class of the original network as globals (merged over this module's globals), and an empty locals
    mapping.

    Args:
        ast_node (ast.AST): Expression ast to evaluate, e.g. the test of an `if` statement.

    Returns:
        (bool, bool): `(success, value)`.
            - `success` is False (and `value` is False) when the source file or module of the network
              cannot be found, or when evaluating the expression or taking its truth value raises.
            - Otherwise `value` is the truth value of the result.
    """
    # ast.Constant can be checked without eval
    if isinstance(ast_node, ast.Constant):
        # Read the value of the node (not an attribute of the `ast` module).
        return True, bool(ast_node.value)
    # Get the module where the code of ast_node is located
    try:
        file_path = inspect.getfile(type(self.get_origin_network()))
    except TypeError as e:
        # inspect.getfile raises TypeError for classes without a source file (e.g. defined interactively).
        logger.warning(f"Failed to get module of ast_node, err:{e}")
        return False, False
    file_path = os.path.normcase(file_path)
    module = None
    for m in list(sys.modules.values()):
        if hasattr(m, "__file__") and m.__file__ and os.path.normcase(m.__file__) == file_path:
            module = m
            break
    if not module:
        logger.warning("Failed to get module of ast_node.")
        return False, False
    # eval ast_node and get result
    logger.debug(f"Eval ast node: {astunparse.unparse(ast_node)}")
    ast_expr = ast.fix_missing_locations(ast.Expression(ast_node))
    try:
        # NOTE: evaluating a compiled ast is not sandboxed; the expression may call arbitrary code.
        # It presumably comes from the source of the user's own network, so it is trusted like that code.
        # Locals are an empty dict so that names of this method (self, module, ...) cannot shadow the
        # names of the network's module.
        # pylint: disable=eval-used
        raw_result = eval(compile(ast_expr, "eval_ast_result", "eval"), {**globals(), **module.__dict__}, {})
        # bool() may raise too (e.g. for objects with ambiguous truth value), so keep it in the try.
        result = bool(raw_result)
    except Exception as e:  # pylint: disable=broad-except
        logger.debug(f"Cannot get result of ast_node by eval, err:{e}")
        return False, False
    logger.debug(f"Eval ast result success, result: {raw_result}")
    return True, result
|
|
1530
|
+
|
|
1531
|
+
def flatten_static_if_control_flow(self):
    """
    Resolve `if` statements whose test result is known statically.

    For every ControlFlow node that still belongs to a symbol tree and has a non-None `test_result`:

    - truthy result: the body branch is flattened and the other branch erased;
    - falsy result with an orelse branch: the orelse branch is flattened and the other branch erased;
    - falsy result without an orelse branch: the body branch is erased.

    Flattening is done with `flatten_nodes(..., True, True)`. This also erases the nodes that follow
    a flattened return node.
    """
    # Walk a snapshot: flattening and erasing change the node lists during the loop.
    for candidate in self.all_nodes()[:]:
        owner_stree = candidate.get_belong_symbol_tree()
        if not owner_stree:
            # already erased while an earlier node was processed
            continue
        if not isinstance(candidate, ControlFlow) or candidate.test_result is None:
            continue
        if candidate.test_result:
            owner_stree.flatten_nodes(candidate.body_node, True, True)
        elif candidate.orelse_node is not None:
            owner_stree.flatten_nodes(candidate.orelse_node, True, True)
        else:
            owner_stree.erase_node(candidate.body_node)
|
|
1548
|
+
|
|
1549
|
+
def add_custom_codes(self, code: str):
    """
    Register user custom codes for the generated module.

    Args:
        code (str): Python source. Each of its top-level statements is parsed and appended, in order, to
            the list of custom code asts.
    """
    parsed_module = ast.parse(code)
    for statement in parsed_module.body:
        self._custom_codes.append(statement)
|
|
1553
|
+
|
|
1554
|
+
def get_custom_codes(self) -> List[ast.AST]:
    """
    Get user custom codes.

    Returns:
        List[ast.AST]: The stored list of custom code asts (the list `add_custom_codes` extends).
        It is returned itself, not a copy.
    """
    return self._custom_codes
|
|
1557
|
+
|
|
1558
|
+
def save_file_path_to_sys(self, level_num, file_path, belonging_ast: ast.AST = None):
    """
    Save the directory of `file_path` into the import asts of `belonging_ast`.

    Two asts are appended to the list returned by `_get_imports_list_of_ast(belonging_ast)`:
    `import sys` and `sys.path.insert(0, <dir>)`.

    `<dir>` is the normalized directory of `file_path`. `level_num` is used when level exists in
    ast.ImportFrom:

    - level_num = 0 (e.g. from xxx import yyy): the directory of `file_path` is saved.
    - level_num = 1 (e.g. from .xxx import yyy): the directory of `file_path` is saved.
    - level_num = 2 (e.g. from ..xxx import yyy): the directory one level above it is saved, and so on.

    Args:
        level_num (int): Level of the relative import.
        file_path (str): Path of a file.
        belonging_ast (ast.AST, optional): Father class / external function ast owning the imports;
            None means the symbol tree itself. Default: None.
    """
    dir_path = os.path.normpath(os.path.normcase(os.path.dirname(os.path.abspath(file_path))))
    for _ in range(level_num - 1):
        dir_path = os.path.dirname(dir_path)
    # Embed the path with repr() instead of r'...': a raw string literal cannot contain its own quote
    # character nor end with a backslash (e.g. 'C:\\'), which made ast.parse fail for such paths.
    sys_path_append_ast = ast.parse(f"sys.path.insert(0, {dir_path!r})").body[0]
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    import_asts.append(ast.Import([ast.alias(name='sys', asname=None)]))
    import_asts.append(sys_path_append_ast)
|
|
1577
|
+
|
|
1578
|
+
def save_imports_from_file(self, file_path, belonging_ast: ast.AST = None):
    """
    Save imports from file into the import asts of `belonging_ast`.

    The directory of `file_path` is first saved via `save_file_path_to_sys`, presumably so that the copied
    imports resolve from generated code. Every import ast found in the file is then passed through
    `SymbolTree._process_relative_import`; each non-empty result is appended.

    Args:
        file_path (str): Source file whose imports are copied.
        belonging_ast (ast.AST, optional): Father class / external function ast owning the imports;
            None means the symbol tree itself. Default: None.

    Raises:
        RuntimeError: If `file_path` does not exist. Nothing is recorded in that case.
    """
    # Validate first, so a missing file does not leave `import sys` / `sys.path.insert` asts behind.
    if not os.path.exists(file_path):
        raise RuntimeError(f"For MindSpore Rewrite, in module parser, file {file_path} not exist.")
    self.save_file_path_to_sys(0, file_path, belonging_ast)
    with open(file_path, "r", encoding="utf-8") as f:
        source_code = f.read()
    import_nodes = AstImportFinder(ast.parse(dedent(source_code))).get_import_node()
    if not import_nodes:
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    for import_node in import_nodes:
        processed_node = SymbolTree._process_relative_import(import_node, file_path)
        if processed_node:
            import_asts.append(processed_node)
|
|
1594
|
+
|
|
1595
|
+
def add_import(self, module: types.ModuleType, name: str, belonging_ast: ast.AST = None):
    """
    add codes: from `module` import `name`

    Depending on `module`, one of the following is appended to the import asts of `belonging_ast`:

    - `__main__`: `name = getattr(sys.modules['__main__'], 'name')`, to avoid executing `__main__` again;
    - `builtins`: nothing, built-in names need no import;
    - otherwise: `from <mod> import name`. `<mod>` comes from the longest `sys.path` entry that prefixes
      the module's file and is accepted by `SymbolTree._check_import`. If no entry is accepted, the
      module's directory is saved via `save_file_path_to_sys` and `<mod>` is the file name up to its
      first '.'.

    Nothing is added when `module` has no attribute `name` (it may be a local variable).

    Args:
        module (types.ModuleType): Module that defines `name`.
        name (str): Name to import.
        belonging_ast (ast.AST, optional): Father class / external function ast owning the import;
            None means the symbol tree itself. Default: None.

    Raises:
        TypeError: If `module` is not a module.
    """
    if not isinstance(module, types.ModuleType):
        raise TypeError(f"For add_import, module should be ModuleType, but got {type(module)}")
    if not hasattr(module, name):
        logger.info(f"module {module.__name__} doesn't have attr '{name}', it may be a local variable.")
        return
    # add imports to import_asts of belonging_ast
    import_asts = self._get_imports_list_of_ast(belonging_ast)
    if module.__name__ == "__main__":
        # get attr from module instead of import to avoid duplicate execution of __main__ module
        code = f"{name} = getattr(sys.modules['__main__'], '{name}')"
        import_asts.append(ast.parse(code).body[0])
        return
    if module.__name__ == "builtins":
        # built-in functions are not need to be imported
        return
    # add import of obj to ast
    func_file_path = os.path.normcase(inspect.getabsfile(module))
    prefix_paths = []
    for path in sys.path:
        path = os.path.normcase(path)
        if func_file_path.startswith(path):
            prefix_paths.append(path)
    # Longest prefixes first: they yield the shortest, most specific module path.
    prefix_paths.sort(key=len, reverse=True)
    for path in prefix_paths:
        import_path = func_file_path[len(path):]
        import_str = import_path.replace(os.path.sep, '.')
        import_str = import_str[1:]  # remove first '.'
        mod = import_str.rsplit('.', 1)[0]
        if SymbolTree._check_import(func_file_path[:len(path)], mod):
            import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
            import_asts.append(import_node)
            break
    else:
        # No sys.path entry was accepted: import the module from its own directory.
        self.save_file_path_to_sys(0, func_file_path, belonging_ast)
        # Split at the first '.', which also strips suffixes like '.cpython-37m-x86_64-linux-gnu.so'.
        mod = os.path.basename(func_file_path).rsplit('.')[0]
        import_node = ast.ImportFrom(module=mod, names=[ast.alias(name=name, asname=None)], level=0)
        import_asts.append(import_node)
|
|
1636
|
+
|
|
1637
|
+
def _get_imports_list_of_ast(self, belonging_ast: ast.AST):
    """
    Return the stored list of import asts that belongs to `belonging_ast`.

    The father class asts are looked up first, then the external function asts. The symbol tree's own
    `_import_asts` is returned when `belonging_ast` is None or found in neither mapping. The stored list
    itself is returned, so appending to it records new imports.
    """
    if belonging_ast is None:
        return self._import_asts
    if belonging_ast in self._father_class_ast:
        return self._father_class_ast.get(belonging_ast)
    if belonging_ast in self._external_ast:
        return self._external_ast.get(belonging_ast)
    return self._import_asts
|
|
1196
1646
|
|
|
1197
1647
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1198
1648
|
if isinstance(node_or_name, str):
|
|
@@ -1265,7 +1715,7 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1265
1715
|
time.sleep(0.5)
|
|
1266
1716
|
i += 1
|
|
1267
1717
|
if not tmp_module:
|
|
1268
|
-
|
|
1718
|
+
raise ImportError(f"load module {tmp_module_name} failed.")
|
|
1269
1719
|
# Save new module to sys.modules to support inspect.getsource().
|
|
1270
1720
|
sys.modules[tmp_module_name] = tmp_module
|
|
1271
1721
|
network_cls = getattr(tmp_module, self._opt_cls_name)
|
|
@@ -1295,6 +1745,75 @@ class SymbolTree(Observer, Observable, NodeManager):
|
|
|
1295
1745
|
for c in cells:
|
|
1296
1746
|
new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
|
|
1297
1747
|
# merge primitives
|
|
1748
|
+
# pylint: disable=protected-access
|
|
1298
1749
|
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1299
1750
|
for p in primitives:
|
|
1300
|
-
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1751
|
+
new_net._primitives[p] = self._origin_network._primitives[p] # pylint: disable=protected-access
|
|
1752
|
+
|
|
1753
|
+
def _process_duplicate_name_modules(self, module_ast: ast.Module):
    """
    Adjust names of imported modules with same name and different import path.

    `module_ast.body` is walked in order. Each `from x import A` records the path `x` under the imported
    name `A` (its asname when present).

    When `A` is imported again, its alias is renamed to `A_{idx}` whenever the matched, or newly
    appended, path is not the first one recorded for `A`. A path matches when it is the same module or
    a parent/child package of a recorded path.

    Uses of the renamed names inside the following ClassDef / FunctionDef asts are replaced too, up to
    the next '#' divider. Every rename is recorded in `self._tmp_replacers`.

    Args:
        module_ast (ast.Module): Generated module, modified in place.
    """
    # {name1: [path1, path2, ...], ...}
    name_path_dict: Dict[str, List[str]] = {}
    # names of modules need to be suffixed: {name1: suffixed_name1, ...}
    name_need_suffix: Dict[str, str] = {}
    # used to record replace actions in ast.ImportFrom
    import_replacer = AstReplacer(None)
    self._tmp_replacers.append(import_replacer)

    def suffix_alias(alias: ast.alias, suffix: int):
        """suffix the name of alias in ast.ImportFrom"""
        new_name = f"{alias.asname}_{suffix}" if alias.asname else f"{alias.name}_{suffix}"
        import_replacer._trace.append((alias, 'asname', alias.asname, new_name))  # pylint: disable=protected-access
        alias.asname = new_name
        return new_name

    def is_divider(ast_node):
        """judge if ast node is divider of new class or function by checking ast.Expr of '#'."""
        return isinstance(ast_node, ast.Expr) and isinstance(ast_node.value, ast.Name) and ast_node.value.id == '#'

    def is_same_or_sub_module(sub_module: str, module: str) -> bool:
        """Return True if `sub_module` is `module` itself or a module inside package `module`."""
        # Compare on '.' boundaries so that e.g. 'a.bc' is not mistaken for a sub module of 'a.b'.
        return sub_module == module or sub_module.startswith(module + '.')

    def record_imports(ast_node: ast.ImportFrom):
        """record name and path of imported modules to find the duplicate name modules."""
        # Keep the leading dots of relative imports. `from . import A` has module None, which would crash
        # the string comparisons, and a relative path must never match an absolute one.
        module_path = '.' * (ast_node.level or 0) + (ast_node.module or '')
        for alias in ast_node.names[:]:
            name = alias.asname if alias.asname else alias.name
            if name == '*':
                continue
            # current name is firstly imported, just record it
            if name not in name_path_dict:
                name_path_dict[name] = [module_path]
                continue
            # current name is imported before, check whether it is a duplicated name
            for idx, path in enumerate(name_path_dict[name]):
                if is_same_or_sub_module(path, module_path):
                    # e.g. origin code is 'from a.b.c import A' and new code is 'from a.b import A'
                    # then we update name_path_dict[name][idx] from 'a.b.c' to 'a.b' and update name to A_{idx}
                    name_path_dict[name][idx] = module_path
                    if idx > 0:
                        name_need_suffix[name] = suffix_alias(alias, idx)
                    break
                if is_same_or_sub_module(module_path, path):
                    # e.g. origin code is 'from a.b import A' and new code is 'from a.b.c import A'
                    # then we just need to update name to A_{idx}
                    if idx > 0:
                        name_need_suffix[name] = suffix_alias(alias, idx)
                    break
            else:
                # current name is imported from a new path, save the path and update the name
                name_path_dict[name].append(module_path)
                name_need_suffix[name] = suffix_alias(alias, len(name_path_dict[name]) - 1)

    def suffix_names_in_ast(ast_node: Union[ast.ClassDef, ast.FunctionDef]):
        """suffix names in ast.ClassDef or ast.FunctionDef"""
        if not name_need_suffix:
            return
        name_replacer = AstReplacer(ast_node)
        self._tmp_replacers.append(name_replacer)
        for name, new_name in name_need_suffix.items():
            name_replacer.replace_all(name, new_name)

    for ast_node in module_ast.body:
        if isinstance(ast_node, ast.ImportFrom):
            record_imports(ast_node)
        if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef)):
            suffix_names_in_ast(ast_node)
        if is_divider(ast_node):
            # renames only apply up to the next divider
            name_need_suffix.clear()
|