mindspore 2.1.0__cp39-none-any.whl → 2.2.11__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Ast utils for create or update ast node."""
|
|
16
|
-
from typing import Optional
|
|
16
|
+
from typing import Optional, List
|
|
17
17
|
import ast
|
|
18
18
|
|
|
19
19
|
from ..api.scoped_value import ScopedValue, ValueType
|
|
@@ -34,20 +34,15 @@ class AstModifier(ast.NodeTransformer):
|
|
|
34
34
|
Returns:
|
|
35
35
|
A bool if to_erase-node been found and been erased.
|
|
36
36
|
"""
|
|
37
|
-
|
|
37
|
+
return AstModifier.erase_ast_from_bodies(ast_func.body, to_erase)
|
|
38
|
+
|
|
39
|
+
@staticmethod
|
|
40
|
+
def erase_ast_from_bodies(ast_bodies: List[ast.AST], to_erase: ast.AST) -> bool:
|
|
41
|
+
"""Erase ast node from ast bodies."""
|
|
42
|
+
for body in ast_bodies:
|
|
38
43
|
if id(body) == id(to_erase):
|
|
39
|
-
|
|
44
|
+
ast_bodies.remove(body)
|
|
40
45
|
return True
|
|
41
|
-
# hardcode for ast.If
|
|
42
|
-
if isinstance(body, ast.If):
|
|
43
|
-
for if_body in body.body:
|
|
44
|
-
if id(if_body) == id(to_erase):
|
|
45
|
-
body.body.remove(if_body)
|
|
46
|
-
return True
|
|
47
|
-
for else_body in body.orelse:
|
|
48
|
-
if id(else_body) == id(to_erase):
|
|
49
|
-
body.orelse.remove(else_body)
|
|
50
|
-
return True
|
|
51
46
|
return False
|
|
52
47
|
|
|
53
48
|
@staticmethod
|
|
@@ -147,13 +142,6 @@ class AstModifier(ast.NodeTransformer):
|
|
|
147
142
|
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
|
|
148
143
|
"""
|
|
149
144
|
assign = AstModifier.create_call_assign(targets, expr, args, kwargs)
|
|
150
|
-
arguments: ast.arguments = ast_func.args
|
|
151
|
-
if arguments.args:
|
|
152
|
-
for arg in arguments.args:
|
|
153
|
-
if id(arg) == id(index_ast):
|
|
154
|
-
ast_func.body.insert(0, assign)
|
|
155
|
-
ast.fix_missing_locations(ast_func)
|
|
156
|
-
return assign
|
|
157
145
|
return AstModifier.insert_assign_ast_to_function(ast_func, assign, index_ast, insert_before)
|
|
158
146
|
|
|
159
147
|
@staticmethod
|
|
@@ -177,49 +165,63 @@ class AstModifier(ast.NodeTransformer):
|
|
|
177
165
|
Raises:
|
|
178
166
|
RuntimeError: If 'index_ast' is not contained in 'ast_func'.
|
|
179
167
|
"""
|
|
180
|
-
|
|
181
|
-
ast_func.body.append(ast_assign)
|
|
182
|
-
ast.fix_missing_locations(ast_func)
|
|
183
|
-
return ast_assign
|
|
168
|
+
# Insert ast at the frontmost position of function body when index_ast is an argument of function
|
|
184
169
|
arguments: ast.arguments = ast_func.args
|
|
185
|
-
if arguments.args:
|
|
170
|
+
if index_ast and arguments.args:
|
|
186
171
|
for arg in arguments.args:
|
|
187
172
|
if id(arg) == id(index_ast):
|
|
188
173
|
ast_func.body.insert(0, ast_assign)
|
|
189
174
|
ast.fix_missing_locations(ast_func)
|
|
190
175
|
return ast_assign
|
|
191
|
-
|
|
192
|
-
|
|
176
|
+
# Insert ast at position specified by index_ast in function body
|
|
177
|
+
ast_assign = AstModifier.insert_assign_ast_to_bodies(ast_func.body, ast_assign, index_ast, insert_before)
|
|
178
|
+
ast.fix_missing_locations(ast_assign)
|
|
179
|
+
return ast_assign
|
|
180
|
+
|
|
181
|
+
@staticmethod
|
|
182
|
+
def insert_assign_ast_to_bodies(ast_bodies: List[ast.AST], ast_assign: ast.Assign,
|
|
183
|
+
index_ast: Optional[ast.AST] = None, insert_before=True) -> ast.AST:
|
|
184
|
+
"""Insert ast at position specified by index_ast of ast_bodies"""
|
|
185
|
+
# Append ast_assign to ast_bodies when index_ast is None
|
|
186
|
+
if index_ast is None:
|
|
187
|
+
ast_bodies.append(ast_assign)
|
|
188
|
+
return ast_assign
|
|
189
|
+
# Append ast_assign to ast_bodies
|
|
190
|
+
for index, body in enumerate(ast_bodies):
|
|
193
191
|
if id(body) == id(index_ast):
|
|
194
|
-
if insert_before:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
192
|
+
if not insert_before:
|
|
193
|
+
index += 1
|
|
194
|
+
ast_bodies.insert(index, ast_assign)
|
|
195
|
+
ast.fix_missing_locations(body)
|
|
196
|
+
break
|
|
197
|
+
else:
|
|
198
|
+
raise ValueError("insert position is not contained in ast_bodies")
|
|
199
|
+
return ast_assign
|
|
200
|
+
|
|
201
|
+
@staticmethod
|
|
202
|
+
def append_arg_to_function(ast_func: ast.FunctionDef, ast_arg: ast.arg) -> ast.AST:
|
|
203
|
+
"""
|
|
204
|
+
Append an ast.arg to an ast.FunctionDef (e.g. self.construct).
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
ast_func (ast.FunctionDef): An instance of ast.FunctionDef which is "construct" function of network.
|
|
208
|
+
ast_arg (ast.arg): An instance of ast.arg to be inserted in.
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
An instance of ast.arg which has been appended to 'ast_func'.
|
|
212
|
+
|
|
213
|
+
Raises:
|
|
214
|
+
RuntimeError: If 'ast_arg' is not an instance of ast_arg.
|
|
215
|
+
"""
|
|
216
|
+
if not isinstance(ast_arg, ast.arg):
|
|
217
|
+
raise RuntimeError("ast_arg should be an instance of ast.arg.")
|
|
218
|
+
arguments: ast.arguments = ast_func.args
|
|
219
|
+
args: [ast.arg] = arguments.args
|
|
220
|
+
args.append(ast_arg)
|
|
221
|
+
defaults = arguments.defaults
|
|
222
|
+
arg_default = ast.Constant(value=None, kind=None)
|
|
223
|
+
defaults.append(arg_default)
|
|
224
|
+
return ast_arg
|
|
223
225
|
|
|
224
226
|
@staticmethod
|
|
225
227
|
def append_global_vars_expr_to_init(init_func: ast.FunctionDef, targets: [ScopedValue],
|
|
@@ -241,7 +243,7 @@ class AstModifier(ast.NodeTransformer):
|
|
|
241
243
|
An instance of ast.Assign which has been appended to 'init_func'.
|
|
242
244
|
"""
|
|
243
245
|
return AstModifier.insert_assign_to_function(init_func, targets=targets,
|
|
244
|
-
expr=ScopedValue(ValueType.NamingValue, "", "
|
|
246
|
+
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
245
247
|
args=[ScopedValue(ValueType.NamingValue, "obj"),
|
|
246
248
|
ScopedValue.create_variable_value(field)])
|
|
247
249
|
|
|
@@ -265,23 +267,30 @@ class AstModifier(ast.NodeTransformer):
|
|
|
265
267
|
RuntimeError: If 'targets' is None.
|
|
266
268
|
RuntimeError: If value_type of element of 'targets' is not ValueType.NamingValue.
|
|
267
269
|
|
|
268
|
-
RuntimeError: If length of 'targets' is not 1. Multi-targets will be support in the future.
|
|
269
270
|
"""
|
|
270
|
-
if targets is None
|
|
271
|
-
raise RuntimeError("
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
271
|
+
if targets is None:
|
|
272
|
+
raise RuntimeError("'Targets should not be None.")
|
|
273
|
+
targets_list = []
|
|
274
|
+
for target in targets:
|
|
275
|
+
if target.type != ValueType.NamingValue:
|
|
276
|
+
raise RuntimeError("Target must be a right-value, got: ", target)
|
|
277
|
+
if target.scope:
|
|
278
|
+
ast_target = ast.Attribute(ast.Name(target.scope, ast.Load()), target.value, ast.Store())
|
|
279
|
+
else:
|
|
280
|
+
ast_target = ast.Name(target.value, ast.Store())
|
|
281
|
+
targets_list.append(ast_target)
|
|
278
282
|
call = AstModifier.create_call(expr, args, kwargs)
|
|
279
|
-
|
|
283
|
+
|
|
284
|
+
if len(targets) == 1:
|
|
285
|
+
result = ast.Assign(targets=[targets_list[0]], value=call)
|
|
286
|
+
elif len(targets) > 1:
|
|
287
|
+
ast_targets = ast.Tuple(elts=targets_list, ctx=ast.Store())
|
|
288
|
+
result = ast.Assign(targets=[ast_targets], value=call)
|
|
280
289
|
ast.fix_missing_locations(result)
|
|
281
290
|
return result
|
|
282
291
|
|
|
283
292
|
@staticmethod
|
|
284
|
-
def
|
|
293
|
+
def _create_arg_by_constant_value(value: ScopedValue):
|
|
285
294
|
"""
|
|
286
295
|
Create an instance of ast.Constant.
|
|
287
296
|
|
|
@@ -290,17 +299,16 @@ class AstModifier(ast.NodeTransformer):
|
|
|
290
299
|
|
|
291
300
|
Raises:
|
|
292
301
|
RuntimeError: if scope of value is not empty.
|
|
293
|
-
TypeError: type of arg not
|
|
302
|
+
TypeError: type of arg is not ValueType.ConstantValue
|
|
294
303
|
|
|
295
304
|
Returns:
|
|
296
305
|
ast.Constant: An instance of ast.Constant
|
|
297
306
|
"""
|
|
298
|
-
if value.type
|
|
307
|
+
if value.type == ValueType.ConstantValue:
|
|
299
308
|
if value.scope:
|
|
300
309
|
raise RuntimeError("For arg the scope should be empty")
|
|
301
310
|
return ast.Constant(value=value.value, kind=None)
|
|
302
|
-
raise TypeError("Type of arg only support
|
|
303
|
-
f" ValueType.StringValue], but got {type(value)}")
|
|
311
|
+
raise TypeError("Type of arg only support ValueType.ConstantValue, but got {type(value)}")
|
|
304
312
|
|
|
305
313
|
@staticmethod
|
|
306
314
|
def _create_list_or_tuple(value: ScopedValue):
|
|
@@ -315,7 +323,7 @@ class AstModifier(ast.NodeTransformer):
|
|
|
315
323
|
"""
|
|
316
324
|
elts = []
|
|
317
325
|
for v in value.value:
|
|
318
|
-
elts.append(AstModifier.
|
|
326
|
+
elts.append(AstModifier._create_arg_by_constant_value(v))
|
|
319
327
|
if isinstance(value, list):
|
|
320
328
|
return ast.List(elts=elts)
|
|
321
329
|
return ast.Tuple(elts=elts)
|
|
@@ -331,22 +339,20 @@ class AstModifier(ast.NodeTransformer):
|
|
|
331
339
|
|
|
332
340
|
Raises:
|
|
333
341
|
RuntimeError: if scope of value is not empty.
|
|
334
|
-
TypeError: type of arg not
|
|
335
|
-
ValueType.ListValue, ValueType.TupleValue]
|
|
342
|
+
TypeError: type of arg is not ValueType.ConstantValue
|
|
336
343
|
|
|
337
344
|
Returns:
|
|
338
345
|
ast.keyword: a instance of ast.keyword.
|
|
339
346
|
"""
|
|
340
347
|
if value.scope:
|
|
341
348
|
raise RuntimeError("value.scope should be empty")
|
|
342
|
-
if value.type
|
|
349
|
+
if value.type == ValueType.ConstantValue:
|
|
343
350
|
v = ast.Constant(value=value.value, kind=None)
|
|
344
351
|
elif value.type in (ValueType.ListValue, ValueType.TupleValue):
|
|
345
352
|
v = AstModifier._create_list_or_tuple(value)
|
|
346
353
|
else:
|
|
347
|
-
raise TypeError("Type of keyword value only support [ValueType.
|
|
348
|
-
"ValueType.
|
|
349
|
-
f"but got {type(value)}")
|
|
354
|
+
raise TypeError("Type of keyword value only support [ValueType.ConstantValue, ValueType.ListValue, "
|
|
355
|
+
f"ValueType.TupleValue], but got {type(value)}")
|
|
350
356
|
return ast.keyword(arg=arg, value=v)
|
|
351
357
|
|
|
352
358
|
@staticmethod
|
|
@@ -371,14 +377,14 @@ class AstModifier(ast.NodeTransformer):
|
|
|
371
377
|
for arg in args:
|
|
372
378
|
if not isinstance(arg, ScopedValue):
|
|
373
379
|
raise TypeError("arg should be ScopedValue, got: ", type(arg))
|
|
374
|
-
if arg.type
|
|
375
|
-
results.append(AstModifier.
|
|
380
|
+
if arg.type == ValueType.ConstantValue:
|
|
381
|
+
results.append(AstModifier._create_arg_by_constant_value(arg))
|
|
376
382
|
elif arg.type == ValueType.NamingValue:
|
|
377
383
|
if arg.scope:
|
|
378
384
|
results.append(ast.Attribute(ast.Name(arg.scope, ast.Load()), arg.value, ast.Store()))
|
|
379
385
|
else:
|
|
380
386
|
results.append(ast.Name(arg.value, ast.Store()))
|
|
381
|
-
elif arg.type
|
|
387
|
+
elif arg.type in (ValueType.ListValue, ValueType.TupleValue):
|
|
382
388
|
results.append(AstModifier._create_list_or_tuple(arg))
|
|
383
389
|
else:
|
|
384
390
|
raise RuntimeError("Please handle custom-object first")
|
|
@@ -406,8 +412,7 @@ class AstModifier(ast.NodeTransformer):
|
|
|
406
412
|
for arg, value in kwargs.items():
|
|
407
413
|
if not isinstance(value, ScopedValue):
|
|
408
414
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
409
|
-
if value.type in (ValueType.
|
|
410
|
-
ValueType.ListValue, ValueType.TupleValue):
|
|
415
|
+
if value.type in (ValueType.ConstantValue, ValueType.ListValue, ValueType.TupleValue):
|
|
411
416
|
results.append(AstModifier._create_keyword(arg, value))
|
|
412
417
|
elif value.type == ValueType.NamingValue:
|
|
413
418
|
if value.scope:
|
|
@@ -466,7 +471,7 @@ class AstModifier(ast.NodeTransformer):
|
|
|
466
471
|
Raises:
|
|
467
472
|
TypeError: Input src_argument is not a ScopedValue
|
|
468
473
|
RuntimeError: If 'dst_ast' is an instance of ast.Constant but type of 'src_argument' is not
|
|
469
|
-
ValueType.
|
|
474
|
+
ValueType.ConstantValue.
|
|
470
475
|
RuntimeError: If 'dst_ast' is an instance of ast.Name or ast.Attribute but type of 'src_argument' is not
|
|
471
476
|
ValueType.NamingValue.
|
|
472
477
|
RuntimeError: When 'dst_ast' is an instance of ast.Name, scope of 'src_argument' is not empty.
|
|
@@ -480,27 +485,14 @@ class AstModifier(ast.NodeTransformer):
|
|
|
480
485
|
"""
|
|
481
486
|
if not isinstance(src_argument, ScopedValue):
|
|
482
487
|
raise TypeError("src_argument should be ScopedValue, got: ", type(src_argument))
|
|
483
|
-
if isinstance(dst_ast, ast.Constant):
|
|
484
|
-
|
|
485
|
-
raise RuntimeError("src_argument should be a IntValue, FloatValue or StringValue, got:",
|
|
486
|
-
str(src_argument.type))
|
|
487
|
-
dst_ast.value = src_argument.value
|
|
488
|
-
return
|
|
489
|
-
if isinstance(dst_ast, ast.Num):
|
|
490
|
-
if src_argument.type not in [ValueType.IntValue, ValueType.FloatValue]:
|
|
491
|
-
raise RuntimeError("src_argument should be a IntValue or FloatValue, but got:",
|
|
492
|
-
str(src_argument.type))
|
|
493
|
-
dst_ast.n = src_argument.value
|
|
494
|
-
return
|
|
495
|
-
if isinstance(dst_ast, ast.Str):
|
|
496
|
-
if src_argument.type not in [ValueType.StringValue]:
|
|
497
|
-
raise RuntimeError("src_argument should be a StringValue, but got:",
|
|
498
|
-
str(src_argument.type))
|
|
499
|
-
dst_ast.s = src_argument.value
|
|
488
|
+
if isinstance(dst_ast, (ast.Constant, ast.Num, ast.Str)):
|
|
489
|
+
AstModifier.update_arg_value_constant(src_argument, dst_ast)
|
|
500
490
|
return
|
|
501
491
|
if isinstance(dst_ast, ast.Name):
|
|
502
|
-
if src_argument.type not in [ValueType.NamingValue, ValueType.
|
|
503
|
-
|
|
492
|
+
if src_argument.type not in [ValueType.NamingValue, ValueType.ConstantValue]\
|
|
493
|
+
or not isinstance(src_argument.value, str):
|
|
494
|
+
raise RuntimeError("src_argument.type should be ValueType.NamingValue or ValueType.ConstantValue, "
|
|
495
|
+
"but got:", type(src_argument.value).__name__)
|
|
504
496
|
if src_argument.scope:
|
|
505
497
|
raise RuntimeError("src_argument.scope should be empty")
|
|
506
498
|
dst_ast.id = src_argument.value
|
|
@@ -523,3 +515,23 @@ class AstModifier(ast.NodeTransformer):
|
|
|
523
515
|
AstModifier.update_arg_value(src_argument.value[elt_index], elt)
|
|
524
516
|
return
|
|
525
517
|
raise RuntimeError("keyword value type is not supported", type(dst_ast))
|
|
518
|
+
|
|
519
|
+
@staticmethod
|
|
520
|
+
def update_arg_value_constant(src_argument: ScopedValue, dst_ast: ast.AST):
|
|
521
|
+
"""Update 'arg_value' of type constant by 'input_argument'"""
|
|
522
|
+
if src_argument.type != ValueType.ConstantValue:
|
|
523
|
+
raise RuntimeError("src_argument should be a ConstantValue, got:", str(src_argument.type))
|
|
524
|
+
if isinstance(dst_ast, ast.Constant):
|
|
525
|
+
dst_ast.value = src_argument.value
|
|
526
|
+
return
|
|
527
|
+
if isinstance(dst_ast, ast.Num):
|
|
528
|
+
if not isinstance(src_argument.value, (int, float)):
|
|
529
|
+
raise RuntimeError("Type of src_argument should be int or float, but got:",
|
|
530
|
+
type(src_argument.value).__name__)
|
|
531
|
+
dst_ast.n = src_argument.value
|
|
532
|
+
return
|
|
533
|
+
if isinstance(dst_ast, ast.Str):
|
|
534
|
+
if not isinstance(src_argument.value, str):
|
|
535
|
+
raise RuntimeError("Type of src_argument should be str, but got:", type(src_argument.value).__name__)
|
|
536
|
+
dst_ast.s = src_argument.value
|
|
537
|
+
return
|
|
@@ -14,14 +14,20 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Ast optimizer for flatten recursive call."""
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
import sys
|
|
18
|
+
from typing import Any, Tuple, List
|
|
19
|
+
import keyword
|
|
18
20
|
import ast
|
|
19
|
-
from ast import FunctionDef
|
|
20
|
-
import astunparse
|
|
21
21
|
|
|
22
22
|
from mindspore import log as logger
|
|
23
23
|
from ..common import error_str
|
|
24
24
|
|
|
25
|
+
if sys.version_info >= (3, 9):
|
|
26
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
27
|
+
else:
|
|
28
|
+
import astunparse
|
|
29
|
+
|
|
30
|
+
FLATTEN_BLACK_LIST = ["set_vertex_attr",]
|
|
25
31
|
|
|
26
32
|
class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
27
33
|
"""Ast optimizer for flatten recursive call."""
|
|
@@ -40,17 +46,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
40
46
|
ast.BoolOp: ["values"],
|
|
41
47
|
ast.UnaryOp: ["operand"],
|
|
42
48
|
ast.Compare: ["left", "comparators"],
|
|
49
|
+
ast.If: ["test"]
|
|
43
50
|
}
|
|
51
|
+
self._transform_functions = []
|
|
52
|
+
self._transform_if = False
|
|
53
|
+
self._symbol_tree = None # Used to get unique name
|
|
44
54
|
|
|
45
55
|
@staticmethod
|
|
46
|
-
def
|
|
56
|
+
def _check_flatten_black_list(node: ast.AST):
|
|
57
|
+
"""Check whether node in flatten black list"""
|
|
58
|
+
func_name = ""
|
|
59
|
+
# Get func name of node
|
|
60
|
+
if isinstance(node, ast.Call):
|
|
61
|
+
if isinstance(node.func, ast.Name):
|
|
62
|
+
func_name = node.func.id
|
|
63
|
+
elif isinstance(node.func, ast.Attribute):
|
|
64
|
+
func_name = node.func.attr
|
|
65
|
+
# Check func name of node
|
|
66
|
+
if func_name and func_name in FLATTEN_BLACK_LIST:
|
|
67
|
+
return True
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
def _generate_target_name(self, node: ast.AST, target_names):
|
|
47
71
|
"""Generate unique target name."""
|
|
48
72
|
if isinstance(node, ast.Call):
|
|
49
73
|
func = node.func
|
|
50
74
|
if isinstance(func, ast.Name):
|
|
51
|
-
target_name = func.id
|
|
75
|
+
target_name = func.id + "_var"
|
|
52
76
|
elif isinstance(func, ast.Attribute):
|
|
53
|
-
target_name = func.attr
|
|
77
|
+
target_name = func.attr + "_var"
|
|
54
78
|
else:
|
|
55
79
|
logger.info("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
|
|
56
80
|
target_name = "function"
|
|
@@ -67,30 +91,33 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
67
91
|
else:
|
|
68
92
|
logger.info("unhandled type of node while generating new target name: %s ", type(node))
|
|
69
93
|
target_name = type(node).__name__.lower() + "_var"
|
|
94
|
+
# avoid python keyword
|
|
95
|
+
if keyword.iskeyword(target_name):
|
|
96
|
+
target_name = target_name + "_var"
|
|
70
97
|
suffix = 0
|
|
71
98
|
result = target_name
|
|
72
99
|
while result in target_names:
|
|
73
100
|
suffix += 1
|
|
74
101
|
result = f"{target_name}_{suffix}"
|
|
102
|
+
if self._symbol_tree:
|
|
103
|
+
result = self._symbol_tree.unique_name(result)
|
|
75
104
|
target_names.append(result)
|
|
76
105
|
return result
|
|
77
106
|
|
|
78
|
-
|
|
79
|
-
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
|
107
|
+
def _create_new_assign_node(self, node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
|
80
108
|
"""Create new assign node to be inserted into ast.FunctionDef."""
|
|
81
109
|
if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)):
|
|
82
110
|
return "", node
|
|
83
|
-
new_target_name =
|
|
111
|
+
new_target_name = self._generate_target_name(node, target_names)
|
|
84
112
|
return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
|
|
85
113
|
|
|
86
|
-
|
|
87
|
-
def _flatten_list(node_list, target_names):
|
|
114
|
+
def _flatten_list(self, node_list, target_names):
|
|
88
115
|
"""Flatten a list of node."""
|
|
89
116
|
results = list()
|
|
90
117
|
new_list = list()
|
|
91
118
|
for node in node_list:
|
|
92
119
|
if isinstance(node, ast.Call):
|
|
93
|
-
new_target, new_node =
|
|
120
|
+
new_target, new_node = self._create_new_assign_node(node, target_names)
|
|
94
121
|
results.append(new_node)
|
|
95
122
|
new_list.append(ast.Name(id=new_target, ctx=ast.Load()))
|
|
96
123
|
else:
|
|
@@ -99,6 +126,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
99
126
|
|
|
100
127
|
def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
|
|
101
128
|
"""Flatten recursive statement according to different node type."""
|
|
129
|
+
if FlattenRecursiveStmt._check_flatten_black_list(node):
|
|
130
|
+
return []
|
|
102
131
|
flatten_config = self._flatten_table.get(type(node))
|
|
103
132
|
if flatten_config is None:
|
|
104
133
|
return []
|
|
@@ -112,21 +141,21 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
112
141
|
if isinstance(todo, ast.Starred):
|
|
113
142
|
new_list.append(todo)
|
|
114
143
|
continue
|
|
115
|
-
new_target_name, new_node =
|
|
144
|
+
new_target_name, new_node = self._create_new_assign_node(todo, target_names)
|
|
116
145
|
if id(new_node) == id(todo):
|
|
117
146
|
new_list.append(todo)
|
|
118
147
|
else:
|
|
119
148
|
new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
|
|
120
149
|
results.append(new_node)
|
|
121
150
|
if isinstance(todo, (ast.Tuple, tuple)):
|
|
122
|
-
_res, _new_list =
|
|
151
|
+
_res, _new_list = self._flatten_list(new_node.value.elts, [new_target_name])
|
|
123
152
|
new_node.value.elts = _new_list
|
|
124
153
|
results.extend(_res)
|
|
125
154
|
setattr(node, todo_name, new_list)
|
|
126
155
|
elif isinstance(todos, dict):
|
|
127
156
|
new_dict = []
|
|
128
157
|
for key, value in todos:
|
|
129
|
-
new_target_name, new_node =
|
|
158
|
+
new_target_name, new_node = self._create_new_assign_node(value, target_names)
|
|
130
159
|
if id(new_node) == id(value):
|
|
131
160
|
new_dict[key] = value
|
|
132
161
|
else:
|
|
@@ -134,16 +163,15 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
134
163
|
results.append(new_node)
|
|
135
164
|
setattr(node, todo_name, new_dict)
|
|
136
165
|
else:
|
|
137
|
-
new_target_name, new_node =
|
|
166
|
+
new_target_name, new_node = self._create_new_assign_node(todos, target_names)
|
|
138
167
|
if id(new_node) != id(todos):
|
|
139
168
|
setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load()))
|
|
140
169
|
results.append(new_node)
|
|
141
170
|
return results
|
|
142
171
|
|
|
143
|
-
def
|
|
144
|
-
"""
|
|
145
|
-
for
|
|
146
|
-
child = node.body[function_index]
|
|
172
|
+
def _save_target_names(self, target_names, ast_body: List[ast.AST]):
|
|
173
|
+
"""Saving target names in ast_body before getting unique names."""
|
|
174
|
+
for child in ast_body:
|
|
147
175
|
if isinstance(child, (ast.Assign, ast.Expr)):
|
|
148
176
|
child_value = child.value
|
|
149
177
|
else:
|
|
@@ -155,7 +183,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
155
183
|
continue
|
|
156
184
|
targets = child.targets
|
|
157
185
|
for target in targets:
|
|
158
|
-
if not isinstance(target, (ast.Name, ast.Tuple)):
|
|
186
|
+
if not isinstance(target, (ast.Name, ast.Tuple, ast.List)):
|
|
159
187
|
raise RuntimeError(
|
|
160
188
|
error_str(f"currently only support ast.Name targets, but got ast type "
|
|
161
189
|
f"'{type(target).__name__}'", child_node=target, father_node=child))
|
|
@@ -163,7 +191,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
163
191
|
target_name = target.id
|
|
164
192
|
if target_name not in target_names:
|
|
165
193
|
target_names.append(target_name)
|
|
166
|
-
elif isinstance(target, ast.Tuple):
|
|
194
|
+
elif isinstance(target, (ast.Tuple, ast.List)):
|
|
167
195
|
for elt in target.elts:
|
|
168
196
|
if not isinstance(elt, ast.Name):
|
|
169
197
|
raise RuntimeError(
|
|
@@ -174,47 +202,66 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
174
202
|
if target_name not in target_names:
|
|
175
203
|
target_names.append(target_name)
|
|
176
204
|
|
|
177
|
-
def
|
|
178
|
-
"""Traverse
|
|
179
|
-
if node.name != "construct":
|
|
180
|
-
return node
|
|
181
|
-
|
|
205
|
+
def _visit_ast_bodies(self, ast_body: List[ast.AST]):
|
|
206
|
+
"""Traverse nodes in ast_body and flatten nodes recursive."""
|
|
182
207
|
target_names = []
|
|
183
|
-
self.
|
|
184
|
-
index = len(
|
|
208
|
+
self._save_target_names(target_names, ast_body)
|
|
209
|
+
index = len(ast_body) - 1
|
|
185
210
|
while index >= 0:
|
|
186
|
-
child =
|
|
211
|
+
child = ast_body[index]
|
|
187
212
|
if isinstance(child, ast.Assign):
|
|
188
213
|
stmt = child.value
|
|
189
214
|
elif isinstance(child, ast.If):
|
|
190
215
|
if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
|
|
191
|
-
if isinstance(child.body[0].value, ast.
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
child.body =
|
|
198
|
-
|
|
199
|
-
else:
|
|
200
|
-
stmt = child
|
|
201
|
-
else:
|
|
202
|
-
stmt = child
|
|
216
|
+
if not isinstance(child.body[0].value, (ast.Name, ast.Constant)):
|
|
217
|
+
return_val_ast = child.body[0].value
|
|
218
|
+
return_name = self._generate_target_name(return_val_ast, target_names)
|
|
219
|
+
new_assign_code = f"{return_name} = {astunparse.unparse(return_val_ast)}"
|
|
220
|
+
new_assign_ast = ast.parse(new_assign_code).body[0]
|
|
221
|
+
new_return_ast = ast.parse(f"return {return_name}").body[0]
|
|
222
|
+
child.body = [new_assign_ast, new_return_ast]
|
|
223
|
+
stmt = child
|
|
203
224
|
elif isinstance(child, ast.Expr):
|
|
204
225
|
stmt = child.value
|
|
205
226
|
else:
|
|
206
227
|
stmt = child
|
|
207
228
|
results = self._flatten_statement(stmt, target_names)
|
|
208
229
|
if results:
|
|
209
|
-
results
|
|
210
|
-
|
|
211
|
-
node.body.insert(index, result)
|
|
230
|
+
for result in reversed(results):
|
|
231
|
+
ast_body.insert(index, result)
|
|
212
232
|
index += 1
|
|
213
233
|
index -= 1
|
|
234
|
+
|
|
235
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name
|
|
236
|
+
"""Traverse nodes in _transform_functions and flatten recursive nodes."""
|
|
237
|
+
if node.name not in self._transform_functions:
|
|
238
|
+
return node
|
|
239
|
+
self._visit_ast_bodies(node.body)
|
|
240
|
+
return node
|
|
241
|
+
|
|
242
|
+
def visit_If(self, node: ast.If) -> Any: # pylint: disable=invalid-name
|
|
243
|
+
"""Traverse nodes in if node and flatten recursive nodes."""
|
|
244
|
+
if not self._transform_if:
|
|
245
|
+
return node
|
|
246
|
+
self._visit_ast_bodies(node.body)
|
|
247
|
+
if node.orelse:
|
|
248
|
+
self._visit_ast_bodies(node.orelse)
|
|
214
249
|
return node
|
|
215
250
|
|
|
216
|
-
def transform(self, ast_root):
|
|
251
|
+
def transform(self, ast_root, transform_functions=None, stree=None):
|
|
217
252
|
"""Interface of FlattenRecursiveStmt."""
|
|
253
|
+
self._transform_functions = transform_functions if transform_functions else ["construct"]
|
|
254
|
+
self._transform_if = False
|
|
255
|
+
self._symbol_tree = stree
|
|
218
256
|
ast_root = self.visit(ast_root)
|
|
219
257
|
ast_root = ast.fix_missing_locations(ast_root)
|
|
220
258
|
return ast_root
|
|
259
|
+
|
|
260
|
+
def transform_if(self, ast_if, stree=None):
|
|
261
|
+
"""Interface of FlattenRecursiveStmt."""
|
|
262
|
+
self._transform_functions = []
|
|
263
|
+
self._transform_if = True
|
|
264
|
+
self._symbol_tree = stree
|
|
265
|
+
ast_if = self.visit(ast_if)
|
|
266
|
+
ast_if = ast.fix_missing_locations(ast_if)
|
|
267
|
+
return ast_if
|
|
@@ -14,8 +14,12 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Error Log for Rewrite."""
|
|
16
16
|
|
|
17
|
+
import sys
|
|
17
18
|
import ast
|
|
18
|
-
|
|
19
|
+
if sys.version_info >= (3, 9):
|
|
20
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
21
|
+
else:
|
|
22
|
+
import astunparse
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
def error_str(reason: str, child_node: ast.expr = None, father_node: ast.expr = None) -> str:
|