mindspore 2.1.0__cp37-none-any.whl → 2.2.10__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -14,37 +14,39 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
16
|
from typing import Union
|
|
17
|
+
import os
|
|
17
18
|
import ast
|
|
18
19
|
import sys
|
|
19
20
|
import inspect
|
|
20
|
-
import astunparse
|
|
21
21
|
|
|
22
22
|
from mindspore import log as logger
|
|
23
|
-
from mindspore._extends.parse.namespace import CellNamespace
|
|
24
23
|
from mindspore.nn import Cell, SequentialCell
|
|
25
|
-
from mindspore.ops import operations as P
|
|
26
24
|
from mindspore.ops import Primitive
|
|
27
|
-
from mindspore.rewrite.parser_register import ParserRegister
|
|
25
|
+
from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
|
|
28
26
|
from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
|
|
29
27
|
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
30
|
-
from mindspore.rewrite.node import Node, TreeNode
|
|
31
|
-
from mindspore.rewrite.
|
|
32
|
-
from mindspore.rewrite.
|
|
28
|
+
from mindspore.rewrite.node.node import Node, TreeNode
|
|
29
|
+
from mindspore.rewrite.node.node_manager import NodeManager
|
|
30
|
+
from mindspore.rewrite.node.call_function import CallFunction
|
|
31
|
+
from mindspore.rewrite.node.cell_container import CellContainer
|
|
32
|
+
from mindspore.rewrite.parsers.parser import Parser
|
|
33
33
|
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
|
34
|
-
from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder
|
|
34
|
+
from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder
|
|
35
|
+
from mindspore.rewrite.ast_transformers.flatten_recursive_stmt import FlattenRecursiveStmt
|
|
35
36
|
from mindspore.rewrite.ast_helpers import AstReplacer
|
|
36
|
-
from mindspore.rewrite.common.event import Event
|
|
37
37
|
from ..common import error_str
|
|
38
38
|
|
|
39
|
+
if sys.version_info >= (3, 9):
|
|
40
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
41
|
+
else:
|
|
42
|
+
import astunparse
|
|
43
|
+
|
|
39
44
|
|
|
40
45
|
class AssignParser(Parser):
|
|
41
46
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
42
47
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
super(AssignParser, self).__init__()
|
|
46
|
-
self._cell_namespce = CellNamespace('mindspore.nn')
|
|
47
|
-
self._primitive_namespce = CellNamespace('mindspore.ops.operations')
|
|
48
|
+
# Types for creating Cell Container node
|
|
49
|
+
types_for_cell_container = [SequentialCell,]
|
|
48
50
|
|
|
49
51
|
def target(self):
|
|
50
52
|
"""Parse target type."""
|
|
@@ -68,9 +70,9 @@ class AssignParser(Parser):
|
|
|
68
70
|
tuple_values = []
|
|
69
71
|
for tuple_elt in tuple_elts:
|
|
70
72
|
if not isinstance(tuple_elt, (ast.Constant, ast.Name, ast.Attribute)):
|
|
71
|
-
raise RuntimeError(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
|
|
72
|
-
|
|
73
|
-
|
|
73
|
+
raise RuntimeError(error_str(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
|
|
74
|
+
f"but got ast type {type(tuple_elt).__name__}",
|
|
75
|
+
child_node=tuple_elt, father_node=node))
|
|
74
76
|
if isinstance(tuple_elt, ast.Constant):
|
|
75
77
|
tuple_values.append(tuple_elt.value)
|
|
76
78
|
elif isinstance(tuple_elt, ast.Name):
|
|
@@ -116,12 +118,12 @@ class AssignParser(Parser):
|
|
|
116
118
|
father_node=node))
|
|
117
119
|
|
|
118
120
|
@staticmethod
|
|
119
|
-
def _get_func_name(
|
|
121
|
+
def _get_func_name(ast_call: ast.Call) -> str:
|
|
120
122
|
"""
|
|
121
123
|
Get the func name from ast.Call.
|
|
122
124
|
|
|
123
125
|
Args:
|
|
124
|
-
|
|
126
|
+
ast_call (ast.Call): Input ast.Call node.
|
|
125
127
|
|
|
126
128
|
Returns:
|
|
127
129
|
Func name.
|
|
@@ -129,7 +131,7 @@ class AssignParser(Parser):
|
|
|
129
131
|
Raises:
|
|
130
132
|
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
|
131
133
|
"""
|
|
132
|
-
func =
|
|
134
|
+
func = ast_call.func
|
|
133
135
|
if isinstance(func, ast.Name):
|
|
134
136
|
return func.id
|
|
135
137
|
if isinstance(func, ast.Attribute):
|
|
@@ -137,15 +139,16 @@ class AssignParser(Parser):
|
|
|
137
139
|
if isinstance(func, ast.Call):
|
|
138
140
|
return AssignParser._get_func_name(func)
|
|
139
141
|
raise RuntimeError(error_str(f"funcValue should be Name or a Attribute or a Call, but got ast type "
|
|
140
|
-
f"'{type(func).__name__}'", child_node=func, father_node=
|
|
142
|
+
f"'{type(func).__name__}'", child_node=func, father_node=ast_call))
|
|
141
143
|
|
|
142
144
|
@staticmethod
|
|
143
|
-
def _get_func_scope(
|
|
145
|
+
def _get_func_scope(ast_call: ast.Call, node_manager: NodeManager = None) -> str:
|
|
144
146
|
"""
|
|
145
147
|
Get the func scope from ast.Call.
|
|
146
148
|
|
|
147
149
|
Args:
|
|
148
|
-
|
|
150
|
+
ast_call (ast.Call): Input ast.Call node.
|
|
151
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
149
152
|
|
|
150
153
|
Returns:
|
|
151
154
|
Func scope.
|
|
@@ -154,17 +157,17 @@ class AssignParser(Parser):
|
|
|
154
157
|
RuntimeError: FuncValue is not an ast.Name when func is an ast.Attribute.
|
|
155
158
|
RuntimeError: Func of input ast node is not ast.Name or ast.Attribute.
|
|
156
159
|
"""
|
|
157
|
-
func =
|
|
160
|
+
func = ast_call.func
|
|
158
161
|
if isinstance(func, ast.Name):
|
|
159
162
|
return ""
|
|
160
163
|
if isinstance(func, ast.Attribute):
|
|
161
164
|
parser = ParserRegister.instance().get_parser(type(func))
|
|
162
|
-
value = parser.process(None, func)
|
|
165
|
+
value = parser.process(None, func, node_manager)
|
|
163
166
|
return value.rsplit(".", 1)[0]
|
|
164
167
|
if isinstance(func, ast.Call):
|
|
165
|
-
return AssignParser._get_func_scope(func)
|
|
168
|
+
return AssignParser._get_func_scope(func, node_manager)
|
|
166
169
|
raise RuntimeError(error_str(f"funcValue should be Name or a Attribute or a Call, but got ast type "
|
|
167
|
-
f"'{type(func).__name__}'", child_node=func, father_node=
|
|
170
|
+
f"'{type(func).__name__}'", child_node=func, father_node=ast_call))
|
|
168
171
|
|
|
169
172
|
@staticmethod
|
|
170
173
|
def _get_symbol_object(symbol_name, origin_net):
|
|
@@ -205,9 +208,9 @@ class AssignParser(Parser):
|
|
|
205
208
|
return results
|
|
206
209
|
|
|
207
210
|
@staticmethod
|
|
208
|
-
def
|
|
211
|
+
def _get_call_instance(func_scope, func_name, stree: SymbolTree):
|
|
209
212
|
"""
|
|
210
|
-
Get
|
|
213
|
+
Get object instance from ast.Call with type of Cell or Primitive.
|
|
211
214
|
|
|
212
215
|
Args:
|
|
213
216
|
func_scope (str): Func scope.
|
|
@@ -215,21 +218,21 @@ class AssignParser(Parser):
|
|
|
215
218
|
stree (SymbolTree): Belong SymbolTree.
|
|
216
219
|
|
|
217
220
|
Returns:
|
|
218
|
-
|
|
221
|
+
An instance represents operator instance.
|
|
219
222
|
"""
|
|
220
|
-
|
|
221
223
|
if func_scope != "self":
|
|
222
|
-
|
|
223
|
-
func_name)
|
|
224
|
+
return None
|
|
224
225
|
var_dict = stree.get_origin_network().__dict__
|
|
226
|
+
# Instance is of type Cell
|
|
225
227
|
for key, value in var_dict["_cells"].items():
|
|
226
228
|
if key == func_name:
|
|
227
|
-
return
|
|
228
|
-
|
|
229
|
+
return value
|
|
230
|
+
# Instance is of type Primitive
|
|
229
231
|
for key, value in var_dict["_primitives"].items():
|
|
230
232
|
if key == func_name:
|
|
231
|
-
return
|
|
232
|
-
|
|
233
|
+
return value
|
|
234
|
+
# Instance is of other type.
|
|
235
|
+
return None
|
|
233
236
|
|
|
234
237
|
@staticmethod
|
|
235
238
|
def _get_targets(all_targets: ScopedValue) -> [Union[ScopedValue, str]]:
|
|
@@ -240,7 +243,7 @@ class AssignParser(Parser):
|
|
|
240
243
|
if not isinstance(single_target, ScopedValue) and not isinstance(single_target.value, str):
|
|
241
244
|
raise RuntimeError(f"For MindSpore Rewrite, only support str target in tuple, but got type "
|
|
242
245
|
f"{type(single_target).__name__}")
|
|
243
|
-
if single_target.type == ValueType.
|
|
246
|
+
if single_target.type == ValueType.ConstantValue and isinstance(single_target.value, str):
|
|
244
247
|
single_target.type = ValueType.NamingValue
|
|
245
248
|
targets.append(single_target)
|
|
246
249
|
else:
|
|
@@ -251,18 +254,7 @@ class AssignParser(Parser):
|
|
|
251
254
|
def _update_field_in_init(func_scope, func_name, stree: SymbolTree, sub_tree: SymbolTree) -> bool:
|
|
252
255
|
"""
|
|
253
256
|
When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method.
|
|
254
|
-
|
|
255
|
-
Update from:
|
|
256
|
-
|
|
257
|
-
.. code-block::
|
|
258
|
-
|
|
259
|
-
self.field = getattr(self._handler, "field")
|
|
260
|
-
|
|
261
|
-
to:
|
|
262
|
-
|
|
263
|
-
.. code-block::
|
|
264
|
-
|
|
265
|
-
self.field = SubNetwork(global_vars.get("field_args"))
|
|
257
|
+
Add the code like: `self.field = SubNetwork(self.field)`
|
|
266
258
|
|
|
267
259
|
Args:
|
|
268
260
|
func_scope (str): A string represents scope of function symbol.
|
|
@@ -278,39 +270,24 @@ class AssignParser(Parser):
|
|
|
278
270
|
logger.warning("Not support parse operator which is instantiated at runtime now: %s; name: %s", func_scope,
|
|
279
271
|
func_name)
|
|
280
272
|
init_func_ast = stree.get_init_func_ast()
|
|
281
|
-
|
|
282
|
-
|
|
273
|
+
sub_net_obj = sub_tree.get_origin_network()
|
|
274
|
+
sub_net_opt_name = sub_tree.get_opt_cls_name()
|
|
283
275
|
# Add .to_float(mindspore.float16) if origin subnet has this attribute
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
new_code = f"
|
|
287
|
-
|
|
288
|
-
new_code = f"
|
|
276
|
+
new_code = f"{func_scope}.{func_name} = {sub_net_opt_name}({func_scope}.{func_name})"
|
|
277
|
+
if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16:
|
|
278
|
+
new_code = f"{new_code}.to_float(mindspore.float16)"
|
|
279
|
+
elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16:
|
|
280
|
+
new_code = f"{new_code}.to_float(mindspore.bfloat16)"
|
|
289
281
|
new_ast = ast.parse(new_code).body[0]
|
|
290
282
|
init_func_ast.body.append(new_ast)
|
|
291
|
-
return True
|
|
292
|
-
|
|
293
|
-
@staticmethod
|
|
294
|
-
def _convert_ast_binop_to_node(ast_node: ast.BinOp, father_ast_node: ast.Assign) -> Node:
|
|
295
|
-
"""convert ast.BinOp to Node"""
|
|
296
|
-
|
|
297
|
-
# only support ast.Add now
|
|
298
|
-
op = P.Add()
|
|
299
|
-
func_ast = ast.Attribute(value=ast.Name(id='F', ctx=ast.Load()), attr='add', ctx=ast.Load())
|
|
300
|
-
func = ScopedValue.create_naming_value('add', 'F')
|
|
301
|
-
|
|
302
|
-
father_ast_node.value = ast.Call(func=func_ast, args=[ast_node.left, ast_node.right], keywords=[])
|
|
303
|
-
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
|
|
304
|
-
call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
|
|
305
|
-
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
|
|
306
283
|
|
|
307
284
|
@staticmethod
|
|
308
|
-
def _create_inputs_for_cell_container(
|
|
285
|
+
def _create_inputs_for_cell_container(ast_assign) -> ['Node']:
|
|
309
286
|
"""Create inputs for cell container first node."""
|
|
310
|
-
call_ast_node =
|
|
287
|
+
call_ast_node = ast_assign.value
|
|
311
288
|
if not isinstance(call_ast_node, ast.Call):
|
|
312
289
|
raise RuntimeError(error_str(f"when creating input node for cellcontainer, value of input father ast node"
|
|
313
|
-
"is not ast.Call!'", child_node=call_ast_node, father_node=
|
|
290
|
+
"is not ast.Call!'", child_node=call_ast_node, father_node=ast_assign))
|
|
314
291
|
first_node_inputs: ['Node'] = []
|
|
315
292
|
exist_param_name = []
|
|
316
293
|
for arg in call_ast_node.args:
|
|
@@ -330,30 +307,52 @@ class AssignParser(Parser):
|
|
|
330
307
|
|
|
331
308
|
if call_ast_node.keywords:
|
|
332
309
|
raise RuntimeError(error_str(f"Not support keyword input for cellcontainer now.",
|
|
333
|
-
child_node=call_ast_node, father_node=
|
|
310
|
+
child_node=call_ast_node, father_node=ast_assign))
|
|
334
311
|
|
|
335
312
|
return first_node_inputs
|
|
336
313
|
|
|
337
|
-
|
|
314
|
+
@staticmethod
|
|
315
|
+
def _update_cell_container_in_init(stree, container_name, container_idx, subnet_opt_name):
|
|
316
|
+
"""
|
|
317
|
+
When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object.
|
|
318
|
+
So the assign code will be got from origin code first, and then be modified to new class name.
|
|
319
|
+
|
|
320
|
+
Codes like:
|
|
321
|
+
|
|
322
|
+
`self.container = nn.SequentialCell([ReLU(), MyNet()])`
|
|
323
|
+
|
|
324
|
+
will be updated by add codes:
|
|
325
|
+
|
|
326
|
+
`self.container[1] = MyNetOpt(self.container[1])`
|
|
327
|
+
|
|
328
|
+
"""
|
|
329
|
+
new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])"
|
|
330
|
+
new_ast = ast.parse(new_code).body[0]
|
|
331
|
+
stree.get_init_func_ast().body.append(new_ast)
|
|
332
|
+
|
|
333
|
+
@staticmethod
|
|
334
|
+
def cell_container_process(ast_assign, stree, targets, func_scope_name, call_args, call_kwargs,
|
|
335
|
+
op_name, container_obj):
|
|
338
336
|
""" parse cell container object."""
|
|
339
|
-
cell_container = CellContainer(
|
|
340
|
-
|
|
341
|
-
first_node_inputs = AssignParser._create_inputs_for_cell_container(
|
|
337
|
+
cell_container = CellContainer(ast_assign, targets, func_scope_name, call_args, call_kwargs,
|
|
338
|
+
op_name, stree, container_obj)
|
|
339
|
+
first_node_inputs = AssignParser._create_inputs_for_cell_container(ast_assign)
|
|
342
340
|
for i, cell in enumerate(container_obj):
|
|
343
|
-
|
|
341
|
+
cell_name = type(cell).__name__
|
|
342
|
+
is_sub_tree = is_subtree(cell)
|
|
344
343
|
if is_sub_tree:
|
|
345
344
|
stb = SymbolTreeBuilder(cell)
|
|
346
345
|
new_stree = stb.build()
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
type(cell).__name__, cell)
|
|
346
|
+
sub_node = TreeNode.create_tree_node(new_stree, None, targets, cell_name, call_args,
|
|
347
|
+
call_kwargs, cell_name, cell)
|
|
348
|
+
AssignParser._update_cell_container_in_init(stree, func_scope_name, i, new_stree.get_opt_cls_name())
|
|
351
349
|
else:
|
|
352
|
-
sub_node = Node.create_call_buildin_op(cell,
|
|
353
|
-
|
|
350
|
+
sub_node = Node.create_call_buildin_op(cell, None, targets, cell_name, call_args,
|
|
351
|
+
call_kwargs, cell_name)
|
|
354
352
|
# add sub node to cell_container
|
|
355
|
-
cell_container.append(sub_node)
|
|
356
|
-
# set node inputs
|
|
353
|
+
cell_container.append(sub_node, False)
|
|
354
|
+
# set node inputs, those input nodes are NOT inserted in container, only
|
|
355
|
+
# topological relationship is updated.
|
|
357
356
|
if i == 0:
|
|
358
357
|
for idx, arg_provider in enumerate(first_node_inputs):
|
|
359
358
|
sub_node.set_arg_providers(idx, (arg_provider, 0))
|
|
@@ -361,43 +360,61 @@ class AssignParser(Parser):
|
|
|
361
360
|
sub_node.set_arg_providers(0, (cell_container.node_list[i-1], 0))
|
|
362
361
|
return cell_container
|
|
363
362
|
|
|
364
|
-
|
|
365
|
-
|
|
363
|
+
@staticmethod
|
|
364
|
+
def process_external_function(stree, func_name, file_path):
|
|
365
|
+
"""
|
|
366
|
+
Process external function.
|
|
367
|
+
Ast of external function defined in specifical file_path will be saved to generate codes.
|
|
368
|
+
"""
|
|
366
369
|
for k, m in sys.modules.items():
|
|
367
370
|
if k in ("_ast", "ast"):
|
|
368
371
|
continue
|
|
369
372
|
if hasattr(m, func_name):
|
|
370
373
|
func = getattr(m, func_name)
|
|
374
|
+
if not inspect.isfunction(func):
|
|
375
|
+
continue
|
|
376
|
+
func_source_code_file = inspect.getfile(func)
|
|
377
|
+
if func_source_code_file != file_path:
|
|
378
|
+
continue
|
|
371
379
|
source_code = inspect.getsource(func)
|
|
372
380
|
ast_root: ast.Module = ast.parse(source_code)
|
|
373
|
-
stree.
|
|
381
|
+
stree.get_external_ast().append(ast_root.body[0])
|
|
374
382
|
return func, ast_root.body[0]
|
|
375
|
-
|
|
383
|
+
logger.info(f"Cannot get ast of function {func_name} from {file_path}.")
|
|
384
|
+
return None, None
|
|
376
385
|
|
|
377
386
|
def _process_internal_function(self, stree: SymbolTree, func_name):
|
|
378
387
|
"""Process internal function."""
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
for body in stree.
|
|
388
|
+
func_inst = getattr(stree.get_origin_network(), func_name)
|
|
389
|
+
ast_functiondef = None
|
|
390
|
+
for body in stree.get_class_ast().body:
|
|
382
391
|
if isinstance(body, ast.FunctionDef) and func_name == body.name:
|
|
383
|
-
|
|
384
|
-
return
|
|
385
|
-
|
|
386
|
-
def
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
392
|
+
ast_functiondef = body
|
|
393
|
+
return func_inst, ast_functiondef
|
|
394
|
+
|
|
395
|
+
def _create_callfunction_node(self, targets: [ScopedValue], func_scope_name: ScopedValue, args: [ScopedValue],
|
|
396
|
+
kwargs: {str: ScopedValue}, node_name: str, ast_assign: ast.Assign,
|
|
397
|
+
ast_functiondef: ast.FunctionDef, stree: SymbolTree, instance):
|
|
398
|
+
"""Create CallFunction node for class internal function."""
|
|
399
|
+
node = CallFunction(targets, func_scope_name, args, kwargs, node_name, ast_assign, ast_functiondef,
|
|
400
|
+
stree, instance)
|
|
401
|
+
# expand ast codes
|
|
402
|
+
ast_functiondef = FlattenRecursiveStmt().transform(ast_functiondef, [func_scope_name.value], stree)
|
|
403
|
+
# parse ast codes into CallFunction Node
|
|
404
|
+
parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
|
405
|
+
parser.process(stree, ast_functiondef, node_manager=node)
|
|
406
|
+
return node
|
|
407
|
+
|
|
408
|
+
def _convert_ast_call_to_node(self, ast_call: ast.Call, ast_assign: ast.Assign, stree: SymbolTree,
|
|
409
|
+
node_manager: NodeManager) -> Node:
|
|
394
410
|
"""
|
|
395
411
|
Convert ast.Call to a symbol tree node.
|
|
396
412
|
|
|
397
413
|
Args:
|
|
398
|
-
|
|
399
|
-
|
|
414
|
+
ast_call (ast.Call): An ast.Call of assign node in construct.
|
|
415
|
+
ast_assign (ast.Assign): Assign node in construct.
|
|
400
416
|
stree (SymbolTree): Symbol Tree under parsing.
|
|
417
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
401
418
|
|
|
402
419
|
Returns:
|
|
403
420
|
An instance of Node in Symbol Tree.
|
|
@@ -405,86 +422,63 @@ class AssignParser(Parser):
|
|
|
405
422
|
Raises:
|
|
406
423
|
RuntimeError: If operator instance invoked by assign is undefined.
|
|
407
424
|
"""
|
|
408
|
-
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(
|
|
409
|
-
func_name = AssignParser._get_func_name(
|
|
425
|
+
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(ast_assign.targets[0]))
|
|
426
|
+
func_name = AssignParser._get_func_name(ast_call)
|
|
410
427
|
if func_name is None or func_name == "":
|
|
411
428
|
raise RuntimeError("function name not exist")
|
|
412
|
-
func_scope = AssignParser._get_func_scope(
|
|
413
|
-
|
|
414
|
-
call_args = [AssignParser._create_scopedvalue(arg) for arg in
|
|
415
|
-
call_kwargs = AssignParser._create_kwargs(
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
if
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
+
func_scope = AssignParser._get_func_scope(ast_call, node_manager)
|
|
430
|
+
func_scope_name = ScopedValue.create_naming_value(func_name, func_scope)
|
|
431
|
+
call_args = [AssignParser._create_scopedvalue(arg) for arg in ast_call.args]
|
|
432
|
+
call_kwargs = AssignParser._create_kwargs(ast_call.keywords)
|
|
433
|
+
|
|
434
|
+
func_inst = AssignParser._get_call_instance(func_scope, func_name, stree)
|
|
435
|
+
if func_inst is None:
|
|
436
|
+
# Function is not Cell and Primitive
|
|
437
|
+
if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name):
|
|
438
|
+
# Function defined in current class
|
|
439
|
+
func_inst, ast_functiondef = self._process_internal_function(stree, func_name)
|
|
440
|
+
if ast_functiondef is None:
|
|
441
|
+
raise RuntimeError(f"Find ast of function {func_scope}.{func_name} in symbol tree class failed.")
|
|
442
|
+
node = self._create_callfunction_node(targets, func_scope_name, call_args, call_kwargs, func_name,
|
|
443
|
+
ast_assign, ast_functiondef, stree, func_inst)
|
|
444
|
+
elif is_functional(func_name):
|
|
445
|
+
# Function defined in mindspore.ops.functional
|
|
446
|
+
parser = ParserRegister.instance().get_parser(type(ast_call.func)) # ast.Name or ast.Attribute
|
|
447
|
+
func_name = parser.process(stree, ast_call.func, node_manager).split(".")[-1]
|
|
448
|
+
func_inst = get_functional(func_name)
|
|
449
|
+
node = Node.inner_create_call_function(func_name, ast_assign, func_name, func_inst, targets,
|
|
450
|
+
call_args, call_kwargs)
|
|
429
451
|
else:
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
452
|
+
origin_net_file = inspect.getfile(type(stree.get_origin_network()))
|
|
453
|
+
if not os.path.exists(origin_net_file):
|
|
454
|
+
raise RuntimeError(f"For MindSpore Rewrite, in assign parser, origin_net_file "
|
|
455
|
+
f"{origin_net_file} not exist")
|
|
456
|
+
func_inst, ast_functiondef = AssignParser.process_external_function(stree, func_name, origin_net_file)
|
|
457
|
+
node = Node.inner_create_call_function(func_name, ast_assign, func_name, func_inst, targets,
|
|
458
|
+
call_args, call_kwargs)
|
|
433
459
|
return node
|
|
434
|
-
if isinstance(
|
|
435
|
-
node =
|
|
436
|
-
|
|
460
|
+
if isinstance(func_inst, tuple(AssignParser.types_for_cell_container)):
|
|
461
|
+
node = AssignParser.cell_container_process(ast_assign, stree, targets, func_scope_name, call_args,
|
|
462
|
+
call_kwargs, func_name, func_inst)
|
|
437
463
|
return node
|
|
438
|
-
if isinstance(
|
|
439
|
-
return Node.create_call_buildin_op(
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
if
|
|
443
|
-
|
|
464
|
+
if isinstance(func_inst, Primitive):
|
|
465
|
+
return Node.create_call_buildin_op(func_inst, ast_assign, targets, func_scope_name, call_args, call_kwargs,
|
|
466
|
+
func_name)
|
|
467
|
+
if isinstance(func_inst, Cell):
|
|
468
|
+
if is_subtree(func_inst):
|
|
469
|
+
# Instance of function is user custom network, create sub-symboltree
|
|
470
|
+
stb = SymbolTreeBuilder(func_inst)
|
|
444
471
|
new_stree = stb.build()
|
|
445
|
-
|
|
446
|
-
if changed:
|
|
447
|
-
# class SubSubNet:
|
|
448
|
-
# def __init__(self, global_vars):
|
|
449
|
-
# self._handler = global_vars.get("handler")
|
|
450
|
-
#
|
|
451
|
-
# class SubNet:
|
|
452
|
-
# def __init__(self, global_vars):
|
|
453
|
-
# self._handler = global_vars.get("handler")
|
|
454
|
-
# self._subsubnet = None
|
|
455
|
-
# if xxx:
|
|
456
|
-
# self._subsubnet = SubSubNet(xxx)
|
|
457
|
-
#
|
|
458
|
-
# Assuming there are two instance of SubNet A and B. "if xxx" in A is True, and in B is False.
|
|
459
|
-
# So self._subsubnet in A is an instance of SubSubNet, and in B is None.
|
|
460
|
-
# So After rewrite, A's code:
|
|
461
|
-
# class SubNetA:
|
|
462
|
-
# def __init__(self, global_vars):
|
|
463
|
-
# self._handler = global_vars.get("handler")
|
|
464
|
-
# self._subsubnet = SubSubNet(global_vars.get("subsubnet_args"))
|
|
465
|
-
# while B's code:
|
|
466
|
-
# class SubNetB:
|
|
467
|
-
# def __init__(self, global_vars):
|
|
468
|
-
# self._handler = global_vars.get("handler")
|
|
469
|
-
# self._subsubnet = getattr(self._handler, "_subsubnet")
|
|
470
|
-
# So SubNet should use SubNetA as its code when _update_field_in_init return True.
|
|
471
|
-
# So SubNet should use SubNetB as its code when _update_field_in_init return False or undefined
|
|
472
|
-
# error will occur to "global_vars.get("subsubnet_args")".
|
|
473
|
-
stree.on_change(Event.CodeChangeEvent)
|
|
474
|
-
# Sub-network in main-network is expressed as:
|
|
475
|
-
# self._subnet = SubNet(global_vars.get("subnet_args"))
|
|
476
|
-
# when subnet is changed, its class will change, take SubNet1 as new class-name, so code main-network
|
|
477
|
-
# also need to change:
|
|
478
|
-
# self._subnet = SubNet1(global_vars.get("subnet_args"))
|
|
479
|
-
# so a change in sub-network should also be identified as a change in main-network.
|
|
480
|
-
# so main-network should observe sub-network
|
|
472
|
+
AssignParser._update_field_in_init(func_scope, func_name, stree, new_stree)
|
|
481
473
|
replacer = AstReplacer(new_stree.get_class_ast())
|
|
482
474
|
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
|
483
|
-
return TreeNode.create_tree_node(new_stree,
|
|
484
|
-
func_name, new_stree.get_origin_network())
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
475
|
+
return TreeNode.create_tree_node(new_stree, ast_assign, targets, func_scope_name, call_args,
|
|
476
|
+
call_kwargs, func_name, new_stree.get_origin_network())
|
|
477
|
+
# Instance of function is buildin cells
|
|
478
|
+
return Node.create_call_buildin_op(func_inst, ast_assign, targets, func_scope_name, call_args, call_kwargs,
|
|
479
|
+
func_name)
|
|
480
|
+
raise RuntimeError("For MindSpore Rewrite, unsupported operation in ast.Call found: ",
|
|
481
|
+
type(func_inst).__name__)
|
|
488
482
|
|
|
489
483
|
@staticmethod
|
|
490
484
|
def _tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
|
|
@@ -499,62 +493,62 @@ class AssignParser(Parser):
|
|
|
499
493
|
return True
|
|
500
494
|
|
|
501
495
|
@staticmethod
|
|
502
|
-
def _convert_ast_mathops_to_node(
|
|
503
|
-
|
|
496
|
+
def _convert_ast_mathops_to_node(ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
|
|
497
|
+
ast_assign: ast.Assign) -> Node:
|
|
504
498
|
"""
|
|
505
499
|
Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
|
|
506
500
|
a symbol tree node.
|
|
507
501
|
|
|
508
502
|
Args:
|
|
509
|
-
|
|
503
|
+
ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
|
|
510
504
|
operation in construct function.
|
|
511
|
-
|
|
505
|
+
ast_assign (ast.Assign): Assign node in construct.
|
|
512
506
|
|
|
513
507
|
Returns:
|
|
514
508
|
An instance of Node in Symbol Tree.
|
|
515
509
|
|
|
516
510
|
Raises:
|
|
517
|
-
TypeError: The type of parameter '
|
|
511
|
+
TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
|
|
518
512
|
|
|
519
513
|
"""
|
|
520
|
-
if not isinstance(
|
|
521
|
-
raise TypeError("The type of parameter '
|
|
522
|
-
"ast.BoolOp, ast.Compare), but got ", type(
|
|
514
|
+
if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
|
515
|
+
raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, "
|
|
516
|
+
"ast.BoolOp, ast.Compare), but got ", type(ast_op))
|
|
523
517
|
|
|
524
|
-
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(
|
|
518
|
+
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(ast_assign.targets[0]))
|
|
525
519
|
args = []
|
|
526
|
-
op_type_str = type(
|
|
520
|
+
op_type_str = type(ast_op).__name__
|
|
527
521
|
op_type = ScopedValue.create_naming_value(op_type_str)
|
|
528
522
|
ops = {}
|
|
529
523
|
name = op_type_str
|
|
530
|
-
if isinstance(
|
|
531
|
-
op = type(
|
|
524
|
+
if isinstance(ast_op, ast.BinOp):
|
|
525
|
+
op = type(ast_op.op).__name__
|
|
532
526
|
name = f'{name}_{op}'
|
|
533
527
|
ops['0'] = ScopedValue.create_naming_value(op)
|
|
534
|
-
args.append(AssignParser._create_scopedvalue(
|
|
535
|
-
args.append(AssignParser._create_scopedvalue(
|
|
536
|
-
elif isinstance(
|
|
537
|
-
op = type(
|
|
528
|
+
args.append(AssignParser._create_scopedvalue(ast_op.left))
|
|
529
|
+
args.append(AssignParser._create_scopedvalue(ast_op.right))
|
|
530
|
+
elif isinstance(ast_op, ast.UnaryOp):
|
|
531
|
+
op = type(ast_op.op).__name__
|
|
538
532
|
name = f'{name}_{op}'
|
|
539
533
|
ops['0'] = ScopedValue.create_naming_value(op)
|
|
540
|
-
args.append(AssignParser._create_scopedvalue(
|
|
541
|
-
elif isinstance(
|
|
542
|
-
op = type(
|
|
534
|
+
args.append(AssignParser._create_scopedvalue(ast_op.operand))
|
|
535
|
+
elif isinstance(ast_op, ast.BoolOp):
|
|
536
|
+
op = type(ast_op.op).__name__
|
|
543
537
|
name = f'{name}_{op}'
|
|
544
538
|
ops['0'] = ScopedValue.create_naming_value(op)
|
|
545
|
-
for value in
|
|
539
|
+
for value in ast_op.values:
|
|
546
540
|
args.append(AssignParser._create_scopedvalue(value))
|
|
547
|
-
elif isinstance(
|
|
548
|
-
args.append(AssignParser._create_scopedvalue(
|
|
549
|
-
for idx,
|
|
550
|
-
op = type(
|
|
541
|
+
elif isinstance(ast_op, ast.Compare):
|
|
542
|
+
args.append(AssignParser._create_scopedvalue(ast_op.left))
|
|
543
|
+
for idx, ast_cmp_op in enumerate(ast_op.ops):
|
|
544
|
+
op = type(ast_cmp_op).__name__
|
|
551
545
|
name = f'{name}_{op}'
|
|
552
546
|
ops[str(idx)] = ScopedValue.create_naming_value(op)
|
|
553
|
-
args.append(AssignParser._create_scopedvalue(
|
|
547
|
+
args.append(AssignParser._create_scopedvalue(ast_op.comparators[idx]))
|
|
554
548
|
name = name.lower()
|
|
555
|
-
return Node.create_mathops_node(
|
|
549
|
+
return Node.create_mathops_node(ast_assign, targets, op_type, args, ops, name)
|
|
556
550
|
|
|
557
|
-
def process(self, stree: SymbolTree, node: ast.Assign):
|
|
551
|
+
def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager):
|
|
558
552
|
"""
|
|
559
553
|
Parse ast.Assign and create a node in symbol tree.
|
|
560
554
|
|
|
@@ -566,6 +560,7 @@ class AssignParser(Parser):
|
|
|
566
560
|
Args:
|
|
567
561
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
568
562
|
node ([ast.Assign]): An ast.Assign node.
|
|
563
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
569
564
|
|
|
570
565
|
Raises:
|
|
571
566
|
RuntimeError: Only support one target in assign now.
|
|
@@ -576,18 +571,18 @@ class AssignParser(Parser):
|
|
|
576
571
|
try:
|
|
577
572
|
if len(targets) != 1:
|
|
578
573
|
raise RuntimeError(
|
|
579
|
-
error_str(f"only support one target in assign now.",
|
|
574
|
+
error_str(f"only support one target in assign now.", targets, node))
|
|
580
575
|
value = node.value
|
|
581
576
|
if isinstance(value, ast.Call):
|
|
582
|
-
node_ = self._convert_ast_call_to_node(value, node, stree)
|
|
583
|
-
stree.append_origin_field(node_)
|
|
577
|
+
node_ = self._convert_ast_call_to_node(value, node, stree, node_manager)
|
|
578
|
+
stree.append_origin_field(node_, node_manager)
|
|
584
579
|
elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
|
585
580
|
node_ = AssignParser._convert_ast_mathops_to_node(value, node)
|
|
586
|
-
stree.append_origin_field(node_)
|
|
581
|
+
stree.append_origin_field(node_, node_manager)
|
|
587
582
|
elif isinstance(value, ast.Subscript):
|
|
588
583
|
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
|
|
589
584
|
f"ignored as a python node now")
|
|
590
|
-
stree.try_append_python_node(node, node)
|
|
585
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
591
586
|
elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant,
|
|
592
587
|
ast.Bytes, ast.Str)):
|
|
593
588
|
if isinstance(value, ast.Name):
|
|
@@ -601,7 +596,7 @@ class AssignParser(Parser):
|
|
|
601
596
|
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
|
|
602
597
|
call_args = [AssignParser._create_scopedvalue(value)]
|
|
603
598
|
node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
|
|
604
|
-
stree.append_origin_field(node_)
|
|
599
|
+
stree.append_origin_field(node_, node_manager)
|
|
605
600
|
elif isinstance(value, ast.Tuple):
|
|
606
601
|
if AssignParser._tuple_elts_support_scopledvalue(value):
|
|
607
602
|
# ensure that each element's type in tuple is supported by scopled value
|
|
@@ -611,14 +606,14 @@ class AssignParser(Parser):
|
|
|
611
606
|
args.append(AssignParser._create_scopedvalue(elt))
|
|
612
607
|
node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"),
|
|
613
608
|
args, {}, "tuple")
|
|
614
|
-
stree.append_origin_field(node_)
|
|
609
|
+
stree.append_origin_field(node_, node_manager)
|
|
615
610
|
else:
|
|
616
|
-
logger.
|
|
617
|
-
|
|
618
|
-
stree.try_append_python_node(node, node)
|
|
611
|
+
logger.info(f"some elements in Tuple of assign({astunparse.unparse(node)}) are not supported "
|
|
612
|
+
"in rewrite, fallback to python")
|
|
613
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
619
614
|
elif isinstance(value, (ast.List, ast.Dict)):
|
|
620
615
|
# add these as callmethod node if necessary
|
|
621
|
-
stree.try_append_python_node(node, node)
|
|
616
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
622
617
|
else:
|
|
623
618
|
raise RuntimeError(
|
|
624
619
|
error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
|
|
@@ -626,8 +621,8 @@ class AssignParser(Parser):
|
|
|
626
621
|
f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
|
|
627
622
|
child_node=value, father_node=node))
|
|
628
623
|
except RuntimeError:
|
|
629
|
-
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
|
|
630
|
-
stree.try_append_python_node(node, node)
|
|
624
|
+
logger.info(f"ops-call({astunparse.unparse(node).strip()}) not supported in rewrite, fallback to python")
|
|
625
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
631
626
|
|
|
632
627
|
|
|
633
628
|
g_assign_parser = reg_parser(AssignParser())
|