mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.10__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the registry's advisory for this release for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -15,13 +15,11 @@
|
|
|
15
15
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
16
|
import ast
|
|
17
17
|
|
|
18
|
-
from mindspore.rewrite.parser import Parser
|
|
18
|
+
from mindspore.rewrite.parsers.parser import Parser
|
|
19
19
|
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
20
|
-
from mindspore.rewrite.parser_register import ParserRegister
|
|
21
|
-
|
|
22
|
-
from mindspore.rewrite.parser_register import reg_parser
|
|
20
|
+
from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
|
|
23
21
|
from ..common import error_str
|
|
24
|
-
|
|
22
|
+
from ..node.node_manager import NodeManager
|
|
25
23
|
|
|
26
24
|
class AttributeParser(Parser):
|
|
27
25
|
"""Parse ast.Attribute in construct function to node of SymbolTree."""
|
|
@@ -30,13 +28,14 @@ class AttributeParser(Parser):
|
|
|
30
28
|
"""Parse target type."""
|
|
31
29
|
return ast.Attribute
|
|
32
30
|
|
|
33
|
-
def process(self, stree: SymbolTree, node: ast.Attribute):
|
|
31
|
+
def process(self, stree: SymbolTree, node: ast.Attribute, node_manager: NodeManager):
|
|
34
32
|
"""
|
|
35
33
|
Parse ast.Attribute node.
|
|
36
34
|
|
|
37
35
|
Args:
|
|
38
36
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
39
37
|
node ([ast.Attribute]): An ast.Attribute node.
|
|
38
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
40
39
|
|
|
41
40
|
Returns:
|
|
42
41
|
The value of node.
|
|
@@ -47,8 +46,11 @@ class AttributeParser(Parser):
|
|
|
47
46
|
if not isinstance(node, ast.Attribute):
|
|
48
47
|
raise TypeError(error_str(f"Attribute parser only supports parsing ast.Attribute type nodes, but got "
|
|
49
48
|
f"'{type(node).__name__}'", father_node=node))
|
|
49
|
+
if not isinstance(node.value, (ast.Name, ast.Attribute)):
|
|
50
|
+
raise RuntimeError(error_str(f"Attribute parser only supports (ast.Attribute, ast.Name) as value of "
|
|
51
|
+
f"ast.Attribute, but got '{type(node).__name__}'", father_node=node))
|
|
50
52
|
parser = ParserRegister.instance().get_parser(type(node.value))
|
|
51
|
-
value = parser.process(stree, node.value)
|
|
53
|
+
value = parser.process(stree, node.value, node_manager)
|
|
52
54
|
|
|
53
55
|
return ".".join([value, node.attr])
|
|
54
56
|
|
|
@@ -13,17 +13,19 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
|
|
16
|
-
import sys
|
|
17
|
-
import ast
|
|
18
16
|
import inspect
|
|
17
|
+
from typing import Union, Dict
|
|
18
|
+
import ast
|
|
19
19
|
from mindspore import log as logger
|
|
20
20
|
from mindspore.nn import Cell
|
|
21
21
|
from mindspore._extends.parse.namespace import CellNamespace
|
|
22
22
|
from ..symbol_tree import SymbolTree
|
|
23
|
-
from
|
|
24
|
-
from
|
|
23
|
+
from .parser import Parser
|
|
24
|
+
from .parser_register import ParserRegister, reg_parser
|
|
25
25
|
from ..ast_helpers import AstReplacer
|
|
26
26
|
from ..common import error_str
|
|
27
|
+
from ..parsers.module_parser import ModuleParser
|
|
28
|
+
from ..node.node_manager import NodeManager
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
class AstScopeChecker:
|
|
@@ -106,25 +108,58 @@ class AstScopeChecker:
|
|
|
106
108
|
class ClassDefParser(Parser):
|
|
107
109
|
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
|
|
108
110
|
|
|
111
|
+
# a denied_function_decorator_list which is registered by user
|
|
112
|
+
denied_function_decorator_list = []
|
|
113
|
+
# Entry function of the forward computation process
|
|
114
|
+
entry_function = "construct"
|
|
115
|
+
|
|
109
116
|
def __init__(self):
|
|
110
117
|
"""Constructor"""
|
|
111
118
|
super(ClassDefParser, self).__init__()
|
|
112
119
|
self._cell_namespace = CellNamespace('mindspore.nn')
|
|
113
120
|
|
|
114
121
|
@staticmethod
|
|
115
|
-
def
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
122
|
+
def _process_init_func_ast(init_ast: ast.FunctionDef, class_name: str, is_father_class: bool,
|
|
123
|
+
father_classes: dict):
|
|
124
|
+
"""Process init func"""
|
|
125
|
+
ClassDefParser._modify_arguments_of_init_func(init_ast)
|
|
126
|
+
new_bodies = ClassDefParser._create_bodys_of_init_func(class_name, is_father_class, father_classes)
|
|
127
|
+
init_ast.body = new_bodies
|
|
128
|
+
|
|
129
|
+
@staticmethod
|
|
130
|
+
def _create_bodys_of_init_func(class_name: str, is_father_class: bool, father_classes: dict):
|
|
131
|
+
"""Modify bodys of init func."""
|
|
132
|
+
new_bodies = []
|
|
133
|
+
# update father class init in new class
|
|
134
|
+
father_class_init_bodies = ClassDefParser._father_class_init_process(father_classes, is_father_class)
|
|
135
|
+
new_bodies.extend(father_class_init_bodies)
|
|
136
|
+
# copy variables into new class
|
|
137
|
+
if is_father_class:
|
|
138
|
+
ast_copy_attr = ast.parse(
|
|
139
|
+
"for key, value in obj.__dict__.items():\n"
|
|
140
|
+
" if not key.startswith('__'):\n"
|
|
141
|
+
f" setattr({class_name}, key, value)").body[0]
|
|
142
|
+
new_bodies.append(ast_copy_attr)
|
|
143
|
+
else:
|
|
144
|
+
ast_copy_attr = ast.parse(
|
|
145
|
+
"for key, value in obj.__dict__.items(): setattr(self, key, value)").body[0]
|
|
146
|
+
new_bodies.append(ast_copy_attr)
|
|
147
|
+
return new_bodies
|
|
148
|
+
|
|
149
|
+
@staticmethod
|
|
150
|
+
def _father_class_init_process(father_classes: dict, is_father_class: bool) -> [ast.AST]:
|
|
151
|
+
"""Add ast bodies of code: father_class.__init__(...)"""
|
|
152
|
+
father_class_init_bodies = []
|
|
153
|
+
for idx, father_class in father_classes.items():
|
|
154
|
+
if father_class == "Cell":
|
|
155
|
+
father_class_init_code = "super().__init__()"
|
|
156
|
+
elif is_father_class:
|
|
157
|
+
father_class_init_code = f"{father_class}.__init__(self, obj.__bases__[{idx}])"
|
|
158
|
+
else:
|
|
159
|
+
father_class_init_code = f"{father_class}.__init__(self, obj.__class__.__bases__[{idx}])"
|
|
160
|
+
father_class_init_ast = ast.parse(father_class_init_code).body[0]
|
|
161
|
+
father_class_init_bodies.append(father_class_init_ast)
|
|
162
|
+
return father_class_init_bodies
|
|
128
163
|
|
|
129
164
|
@staticmethod
|
|
130
165
|
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
|
|
@@ -135,127 +170,153 @@ class ClassDefParser(Parser):
|
|
|
135
170
|
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
|
|
136
171
|
ast.fix_missing_locations(ast_init_fn)
|
|
137
172
|
|
|
173
|
+
@staticmethod
|
|
174
|
+
def get_ast_name(ast_node: Union[ast.Name, ast.Attribute]) -> str:
|
|
175
|
+
"""Get ast id name"""
|
|
176
|
+
if isinstance(ast_node, ast.Name):
|
|
177
|
+
return ast_node.id
|
|
178
|
+
if isinstance(ast_node, ast.Attribute):
|
|
179
|
+
return ast_node.attr
|
|
180
|
+
return ""
|
|
181
|
+
|
|
182
|
+
@staticmethod
|
|
183
|
+
def _process_class_variables(stree: SymbolTree, function_defs: list):
|
|
184
|
+
"""Process class variables of class, only used in child class."""
|
|
185
|
+
init_func_ast = stree.get_init_func_ast()
|
|
186
|
+
for key, value in stree.get_origin_network().__class__.__dict__.items():
|
|
187
|
+
if key.startswith('__'):
|
|
188
|
+
# ignore inner functions
|
|
189
|
+
continue
|
|
190
|
+
if callable(value) and key in function_defs:
|
|
191
|
+
# ignore functions defined by self
|
|
192
|
+
continue
|
|
193
|
+
assign_code = f"self.__class__.{key} = obj.__class__.{key}"
|
|
194
|
+
assign_ast = ast.parse(assign_code).body[0]
|
|
195
|
+
init_func_ast.body.append(assign_ast)
|
|
196
|
+
|
|
197
|
+
@staticmethod
|
|
198
|
+
def _need_add_init_func(cls_ast: ast.ClassDef) -> bool:
|
|
199
|
+
"""If the class don't have init func, we need to add an init func"""
|
|
200
|
+
for body in cls_ast.body:
|
|
201
|
+
if isinstance(body, ast.FunctionDef) and body.name == '__init__':
|
|
202
|
+
return False
|
|
203
|
+
return True
|
|
204
|
+
|
|
205
|
+
@staticmethod
|
|
206
|
+
def _add_init_func(cls_ast: ast.ClassDef):
|
|
207
|
+
"""Add init func with super().__init__()"""
|
|
208
|
+
init_func_ast = ast.parse("def __init__(self): super().__init__()").body[0]
|
|
209
|
+
cls_ast.body.insert(0, init_func_ast)
|
|
210
|
+
ast.fix_missing_locations(cls_ast)
|
|
211
|
+
|
|
212
|
+
@staticmethod
|
|
213
|
+
def _process_father_classes(stree, node: ast.ClassDef, cur_class_def: type) -> list:
|
|
214
|
+
"""Process father class."""
|
|
215
|
+
father_classes: Dict[int, str] = {}
|
|
216
|
+
for idx, base in enumerate(node.bases):
|
|
217
|
+
father_class_name = ClassDefParser.get_ast_name(base)
|
|
218
|
+
if not father_class_name:
|
|
219
|
+
continue
|
|
220
|
+
father_classes[idx] = father_class_name
|
|
221
|
+
if father_class_name == "Cell":
|
|
222
|
+
continue
|
|
223
|
+
father_class_def = cur_class_def.__bases__[idx]
|
|
224
|
+
ClassDefParser._process_one_father_class(stree, father_class_def, father_class_name)
|
|
225
|
+
node.bases[idx] = ast.Name(id=father_class_name, ctx=ast.Load())
|
|
226
|
+
return father_classes
|
|
227
|
+
|
|
228
|
+
@staticmethod
|
|
229
|
+
def _process_one_father_class(stree: SymbolTree, father_class_def: type, father_class_name: str):
|
|
230
|
+
"""Process one father class"""
|
|
231
|
+
# save father class's file path and imports into symbol tree
|
|
232
|
+
net_path = inspect.getfile(father_class_def)
|
|
233
|
+
ModuleParser.save_file_path_to_sys(stree, 0, net_path)
|
|
234
|
+
ModuleParser.save_imports_from_file(stree, net_path)
|
|
235
|
+
# get father class's ast
|
|
236
|
+
source_code = inspect.getsource(father_class_def)
|
|
237
|
+
father_class_ast: ast.ClassDef = ast.parse(source_code).body[0]
|
|
238
|
+
# process father class's father classes
|
|
239
|
+
father_classes = ClassDefParser._process_father_classes(stree, father_class_ast, father_class_def)
|
|
240
|
+
# process father class's __init__ function
|
|
241
|
+
if ClassDefParser._need_add_init_func(father_class_ast):
|
|
242
|
+
ClassDefParser._add_init_func(father_class_ast)
|
|
243
|
+
for body in father_class_ast.body[:]:
|
|
244
|
+
if isinstance(body, ast.FunctionDef) and body.name == "__init__":
|
|
245
|
+
# Add function decorator
|
|
246
|
+
ClassDefParser._func_decorator_process(body)
|
|
247
|
+
ClassDefParser._process_init_func_ast(body, father_class_name, True, father_classes)
|
|
248
|
+
else:
|
|
249
|
+
# Remove other codes, which are copied in __init__ function.
|
|
250
|
+
father_class_ast.body.remove(body)
|
|
251
|
+
# save father class's ast into symbol tree
|
|
252
|
+
stree.get_father_class_ast().append(father_class_ast)
|
|
253
|
+
|
|
254
|
+
@staticmethod
|
|
255
|
+
def _func_decorator_process(node: ast.FunctionDef):
|
|
256
|
+
"""
|
|
257
|
+
User should set the denied function decorators,
|
|
258
|
+
because the symbol_tree cant pass the correct parameters to decorators but the instance "obj".
|
|
259
|
+
"""
|
|
260
|
+
for decorator in node.decorator_list[:]:
|
|
261
|
+
decorator_name = ""
|
|
262
|
+
if isinstance(decorator, ast.Call):
|
|
263
|
+
func = decorator.func
|
|
264
|
+
if isinstance(func, ast.Name):
|
|
265
|
+
decorator_name = func.id
|
|
266
|
+
elif isinstance(decorator, ast.Name):
|
|
267
|
+
decorator_name = decorator.id
|
|
268
|
+
if decorator_name in ClassDefParser.denied_function_decorator_list:
|
|
269
|
+
node.decorator_list.remove(decorator)
|
|
270
|
+
|
|
138
271
|
def target(self):
|
|
139
272
|
"""Parse target type"""
|
|
140
273
|
return ast.ClassDef
|
|
141
274
|
|
|
142
|
-
def process(self, stree: SymbolTree, node: ast.ClassDef):
|
|
275
|
+
def process(self, stree: SymbolTree, node: ast.ClassDef, node_manager: NodeManager):
|
|
143
276
|
"""
|
|
144
|
-
Parse init and construct in ast.ClassDef.
|
|
277
|
+
Parse init and entry function(default: construct) in ast.ClassDef.
|
|
145
278
|
|
|
146
279
|
Args:
|
|
147
280
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
148
281
|
node ([ast.ClassDef]): An ast.ClassDef node.
|
|
282
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
149
283
|
"""
|
|
284
|
+
# Update network's class name from xxx to xxxOpt in ast
|
|
150
285
|
replacer = AstReplacer(node)
|
|
151
286
|
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
|
152
287
|
|
|
288
|
+
# process network's father classes
|
|
153
289
|
stree.set_class_ast(node)
|
|
154
|
-
|
|
290
|
+
cur_class_def = type(stree.get_origin_network())
|
|
291
|
+
father_classes = ClassDefParser._process_father_classes(stree, node, cur_class_def)
|
|
155
292
|
|
|
156
|
-
if
|
|
157
|
-
|
|
293
|
+
# add __init__ function to network if necessary
|
|
294
|
+
if isinstance(stree.get_origin_network(), Cell) and ClassDefParser._need_add_init_func(node):
|
|
295
|
+
ClassDefParser._add_init_func(node)
|
|
158
296
|
|
|
159
|
-
|
|
297
|
+
# save function defs in ast node to filter function class variables.
|
|
298
|
+
function_defs = []
|
|
299
|
+
for body in node.body[:]:
|
|
160
300
|
if isinstance(body, ast.FunctionDef):
|
|
301
|
+
function_defs.append(body.name)
|
|
302
|
+
ClassDefParser._func_decorator_process(body)
|
|
161
303
|
if body.name == "__init__":
|
|
162
|
-
self._process_init_func_ast(body, has_father_class)
|
|
163
304
|
stree.set_init_func_ast(body)
|
|
164
|
-
|
|
305
|
+
ClassDefParser._process_init_func_ast(body, stree.get_opt_cls_name(), False, father_classes)
|
|
306
|
+
elif body.name == ClassDefParser.entry_function:
|
|
307
|
+
stree.set_ast_root(body)
|
|
165
308
|
parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
|
166
|
-
parser.process(stree, body)
|
|
309
|
+
parser.process(stree, body, stree)
|
|
167
310
|
else:
|
|
168
311
|
logger.info(
|
|
169
312
|
"Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
|
|
170
313
|
body.name)
|
|
314
|
+
elif isinstance(body, (ast.Assign, ast.If, ast.IfExp)):
|
|
315
|
+
# Remove class variables, which are copied in __init__ function.
|
|
316
|
+
node.body.remove(body)
|
|
171
317
|
else:
|
|
172
318
|
logger.info("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
op = getattr(ori_net, field)
|
|
176
|
-
return not type(op).__name__ in self._cell_namespace
|
|
177
|
-
|
|
178
|
-
def _process_init_func_ast(self, init_ast: ast.FunctionDef, has_father_class: bool):
|
|
179
|
-
"""Process init func"""
|
|
180
|
-
ClassDefParser._modify_arguments_of_init_func(init_ast)
|
|
181
|
-
new_bodies = self._replace_ori_field_of_init_func(init_ast.body, has_father_class)
|
|
182
|
-
init_ast.body = new_bodies
|
|
183
|
-
|
|
184
|
-
def _need_add_init_func(self, stree: SymbolTree, cls_ast: ast.ClassDef) -> bool:
|
|
185
|
-
"""If class is child class of nn.Cell but not have init func, then we need to add init func"""
|
|
186
|
-
if not isinstance(stree.get_origin_network(), Cell):
|
|
187
|
-
return False
|
|
188
|
-
for body in cls_ast.body:
|
|
189
|
-
if isinstance(body, ast.FunctionDef) and body.name == '__init__':
|
|
190
|
-
return False
|
|
191
|
-
return True
|
|
192
|
-
|
|
193
|
-
def _add_init_func(self, cls_ast: ast.ClassDef):
|
|
194
|
-
"""Add init func with super().__init__()"""
|
|
195
|
-
init_func_ast = ast.parse("def __init__(self): super().__init__()").body[0]
|
|
196
|
-
cls_ast.body.insert(0, init_func_ast)
|
|
197
|
-
ast.fix_missing_locations(cls_ast)
|
|
198
|
-
|
|
199
|
-
def _replace_ori_field_of_init_func(self, bodies: [], has_father_class: bool):
|
|
200
|
-
"""
|
|
201
|
-
Replace original field in init func to self.XX = getattr(self._handler, "XX").
|
|
202
|
-
Only keep following two kinds of ast nodes in bodies right now:
|
|
203
|
-
1. Ast.If and test is self.XX.
|
|
204
|
-
2. Ast.Assign and target is self.XX.
|
|
205
|
-
|
|
206
|
-
Args:
|
|
207
|
-
bodies ([]): bodied of init ast.FunctionDef.
|
|
208
|
-
has_father_class (bool): whether class has father class that is not nn.Cell
|
|
209
|
-
|
|
210
|
-
Raises:
|
|
211
|
-
RuntimeError: Not support multi-targets in assign.
|
|
212
|
-
RuntimeError: Only support target.value in [ast.Name] in assign node.
|
|
213
|
-
"""
|
|
214
|
-
new_bodies = []
|
|
215
|
-
for body in bodies:
|
|
216
|
-
if self._is_super_expr(body):
|
|
217
|
-
if has_father_class:
|
|
218
|
-
body.value.args = [ast.Name(id='obj', ctx=ast.Load())]
|
|
219
|
-
body.value.keywords = []
|
|
220
|
-
new_bodies.append(body)
|
|
221
|
-
continue
|
|
222
|
-
ast_copy_attr = ast.parse(
|
|
223
|
-
"for key, value in obj.__dict__.items(): setattr(self, key, value)").body[0]
|
|
224
|
-
new_bodies.append(ast_copy_attr)
|
|
225
|
-
return new_bodies
|
|
226
|
-
|
|
227
|
-
def _handle_father_class(self, stree, node: ast.ClassDef) -> bool:
|
|
228
|
-
"""Handle father class."""
|
|
229
|
-
has_father_class = False
|
|
230
|
-
for base in node.bases:
|
|
231
|
-
parser: Parser = ParserRegister.instance().get_parser(type(base))
|
|
232
|
-
father_class = parser.process(stree, base)
|
|
233
|
-
if "Cell" not in father_class:
|
|
234
|
-
for k, m in sys.modules.items():
|
|
235
|
-
if k in ("_ast", "ast"):
|
|
236
|
-
continue
|
|
237
|
-
if hasattr(m, father_class):
|
|
238
|
-
cls = getattr(m, father_class)
|
|
239
|
-
if not inspect.isclass(cls):
|
|
240
|
-
continue
|
|
241
|
-
source_code = inspect.getsource(cls)
|
|
242
|
-
father_class_ast: ast.Module = ast.parse(source_code)
|
|
243
|
-
self._father_class_process_init_func_ast(stree, father_class_ast)
|
|
244
|
-
stree._father_class_ast.append(father_class_ast) # pylint: disable=protected-access
|
|
245
|
-
has_father_class = True
|
|
246
|
-
break
|
|
247
|
-
return has_father_class
|
|
248
|
-
|
|
249
|
-
def _father_class_process_init_func_ast(self, stree: SymbolTree, father_class_ast: ast.Module):
|
|
250
|
-
father_class_stree: SymbolTree = SymbolTree(stree.get_origin_network(), father_class_ast)
|
|
251
|
-
for ast_body in father_class_ast.body:
|
|
252
|
-
if isinstance(ast_body, ast.ClassDef):
|
|
253
|
-
has_father_class = self._handle_father_class(stree, ast_body)
|
|
254
|
-
if self._need_add_init_func(father_class_stree, ast_body):
|
|
255
|
-
self._add_init_func(ast_body)
|
|
256
|
-
for body in ast_body.body:
|
|
257
|
-
if isinstance(body, ast.FunctionDef) and body.name == "__init__":
|
|
258
|
-
self._process_init_func_ast(body, has_father_class)
|
|
259
|
-
|
|
319
|
+
# Copy function class variables into new network
|
|
320
|
+
ClassDefParser._process_class_variables(stree, function_defs)
|
|
260
321
|
|
|
261
322
|
g_classdef_parser = reg_parser(ClassDefParser())
|
|
@@ -15,11 +15,11 @@
|
|
|
15
15
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
16
|
import ast
|
|
17
17
|
|
|
18
|
-
from mindspore.rewrite.parser import Parser
|
|
18
|
+
from mindspore.rewrite.parsers.parser import Parser
|
|
19
19
|
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
20
|
-
from mindspore.rewrite.parser_register import reg_parser
|
|
20
|
+
from mindspore.rewrite.parsers.parser_register import reg_parser
|
|
21
21
|
from ..common import error_str
|
|
22
|
-
|
|
22
|
+
from ..node.node_manager import NodeManager
|
|
23
23
|
|
|
24
24
|
class NameParser(Parser):
|
|
25
25
|
"""Parse ast.Name in construct function to node of SymbolTree."""
|
|
@@ -28,13 +28,14 @@ class NameParser(Parser):
|
|
|
28
28
|
"""Parse target type."""
|
|
29
29
|
return ast.Name
|
|
30
30
|
|
|
31
|
-
def process(self, stree: SymbolTree, node: ast.Name):
|
|
31
|
+
def process(self, stree: SymbolTree, node: ast.Name, node_manager: NodeManager):
|
|
32
32
|
"""
|
|
33
33
|
Parse ast.Name node.
|
|
34
34
|
|
|
35
35
|
Args:
|
|
36
36
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
37
37
|
node ([ast.Name]): An ast.Name node.
|
|
38
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
38
39
|
|
|
39
40
|
Raises:
|
|
40
41
|
TypeError: Name parser only supports parsing ast.Name type nodes.
|
|
@@ -52,13 +53,14 @@ class NumParser(Parser):
|
|
|
52
53
|
"""Parse target type."""
|
|
53
54
|
return ast.Num
|
|
54
55
|
|
|
55
|
-
def process(self, stree: SymbolTree, node: ast.Num):
|
|
56
|
+
def process(self, stree: SymbolTree, node: ast.Num, node_manager: NodeManager):
|
|
56
57
|
"""
|
|
57
58
|
Parse ast.Num node.
|
|
58
59
|
|
|
59
60
|
Args:
|
|
60
61
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
61
62
|
node ([ast.Num]): An ast.Num node.
|
|
63
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
62
64
|
|
|
63
65
|
Raises:
|
|
64
66
|
TypeError: Num parser only supports parsing ast.Num type nodes.
|
|
@@ -76,13 +78,14 @@ class StrParser(Parser):
|
|
|
76
78
|
"""Parse target type."""
|
|
77
79
|
return ast.Str
|
|
78
80
|
|
|
79
|
-
def process(self, stree: SymbolTree, node: ast.Str):
|
|
81
|
+
def process(self, stree: SymbolTree, node: ast.Str, node_manager: NodeManager):
|
|
80
82
|
"""
|
|
81
83
|
Parse ast.Str node.
|
|
82
84
|
|
|
83
85
|
Args:
|
|
84
86
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
85
87
|
node ([ast.Str]): An ast.Str node.
|
|
88
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
86
89
|
|
|
87
90
|
Returns:
|
|
88
91
|
The value of node.
|
|
@@ -15,12 +15,12 @@
|
|
|
15
15
|
"""Parse Container in construct function to node of SymbolTree."""
|
|
16
16
|
import ast
|
|
17
17
|
|
|
18
|
-
from mindspore.rewrite.parser import Parser
|
|
18
|
+
from mindspore.rewrite.parsers.parser import Parser
|
|
19
19
|
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
20
|
-
from mindspore.rewrite.parser_register import ParserRegister
|
|
20
|
+
from mindspore.rewrite.parsers.parser_register import ParserRegister, reg_parser
|
|
21
21
|
|
|
22
|
-
from mindspore.rewrite.parser_register import reg_parser
|
|
23
22
|
from ..common import error_str
|
|
23
|
+
from ..node.node_manager import NodeManager
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class ListParser(Parser):
|
|
@@ -30,13 +30,14 @@ class ListParser(Parser):
|
|
|
30
30
|
"""Parse target type."""
|
|
31
31
|
return list
|
|
32
32
|
|
|
33
|
-
def process(self, stree: SymbolTree, node: list):
|
|
33
|
+
def process(self, stree: SymbolTree, node: list, node_manager: NodeManager):
|
|
34
34
|
"""
|
|
35
35
|
Parse list.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
38
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
39
39
|
node ([list]): An list of node.
|
|
40
|
+
father_node_managernode (NodeManager): NodeManager those asts belong to.
|
|
40
41
|
|
|
41
42
|
Returns:
|
|
42
43
|
A list of value.
|
|
@@ -50,7 +51,7 @@ class ListParser(Parser):
|
|
|
50
51
|
result = []
|
|
51
52
|
for n in node:
|
|
52
53
|
parser = ParserRegister.instance().get_parser(type(n))
|
|
53
|
-
value = parser.process(stree, n)
|
|
54
|
+
value = parser.process(stree, n, node_manager)
|
|
54
55
|
result.append(value)
|
|
55
56
|
return result
|
|
56
57
|
|
|
@@ -62,13 +63,14 @@ class TupleParser(Parser):
|
|
|
62
63
|
"""Parse target type."""
|
|
63
64
|
return tuple
|
|
64
65
|
|
|
65
|
-
def process(self, stree: SymbolTree, node: tuple):
|
|
66
|
+
def process(self, stree: SymbolTree, node: tuple, node_manager: NodeManager):
|
|
66
67
|
"""
|
|
67
68
|
Parse tuple.
|
|
68
69
|
|
|
69
70
|
Args:
|
|
70
71
|
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
71
72
|
node ([tuple]): An tuple of node.
|
|
73
|
+
node_manager (NodeManager): NodeManager those asts belong to.
|
|
72
74
|
|
|
73
75
|
Returns:
|
|
74
76
|
A tuple of value.
|
|
@@ -79,7 +81,7 @@ class TupleParser(Parser):
|
|
|
79
81
|
result = []
|
|
80
82
|
for n in node:
|
|
81
83
|
parser = ParserRegister.instance().get_parser(type(n))
|
|
82
|
-
value = parser.process(stree, n)
|
|
84
|
+
value = parser.process(stree, n, node_manager)
|
|
83
85
|
result.append(value)
|
|
84
86
|
return tuple(result)
|
|
85
87
|
|
|
@@ -13,17 +13,23 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
""" Parse ast.For node """
|
|
16
|
+
import sys
|
|
16
17
|
import ast
|
|
17
|
-
import astunparse
|
|
18
18
|
|
|
19
19
|
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
|
20
20
|
from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
|
|
21
21
|
from mindspore import log as logger
|
|
22
22
|
from mindspore import nn
|
|
23
23
|
from ..symbol_tree import SymbolTree
|
|
24
|
-
from
|
|
25
|
-
from
|
|
24
|
+
from .parser import Parser
|
|
25
|
+
from .parser_register import reg_parser
|
|
26
26
|
from ..common.event import Event
|
|
27
|
+
from ..node.node_manager import NodeManager
|
|
28
|
+
|
|
29
|
+
if sys.version_info >= (3, 9):
|
|
30
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
31
|
+
else:
|
|
32
|
+
import astunparse
|
|
27
33
|
|
|
28
34
|
EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
|
|
29
35
|
|
|
@@ -34,20 +40,20 @@ class ForParser(Parser):
|
|
|
34
40
|
@staticmethod
|
|
35
41
|
def modify_init_ast(stree, i, obj, iter_var_name):
|
|
36
42
|
"""Modify the ast node in init function."""
|
|
37
|
-
target = f"{iter_var_name.strip()}
|
|
43
|
+
target = f"{iter_var_name.strip()}{str(i)}"
|
|
38
44
|
setattr(stree.get_origin_network(), target, obj)
|
|
39
45
|
stree.get_origin_network().insert_child_to_cell(target, obj)
|
|
40
46
|
AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
|
|
41
47
|
targets=[ScopedValue(ValueType.NamingValue, "self", target)],
|
|
42
48
|
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
43
49
|
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
44
|
-
ScopedValue(ValueType.
|
|
50
|
+
ScopedValue(ValueType.ConstantValue, "", target)])
|
|
45
51
|
|
|
46
52
|
@staticmethod
|
|
47
53
|
def modify_construct_ast(stree, ast_node, old_name, new_name):
|
|
48
54
|
"""Modify the ast node in construct function."""
|
|
49
55
|
node_str: str = astunparse.unparse(ast_node)
|
|
50
|
-
node_str = node_str.replace(old_name, new_name)
|
|
56
|
+
node_str = node_str.replace(old_name+'(', new_name+'(')
|
|
51
57
|
module_node = ast.parse(node_str)
|
|
52
58
|
new_node = module_node.body[0]
|
|
53
59
|
return new_node
|
|
@@ -55,10 +61,15 @@ class ForParser(Parser):
|
|
|
55
61
|
def target(self):
|
|
56
62
|
return ast.For
|
|
57
63
|
|
|
58
|
-
def process(self, stree: SymbolTree, node: ast.For):
|
|
64
|
+
def process(self, stree: SymbolTree, node: ast.For, node_manager: NodeManager):
|
|
59
65
|
""" Process ast.For node """
|
|
60
66
|
if isinstance(node.target, ast.Name):
|
|
61
67
|
targets = node.target.id
|
|
68
|
+
if isinstance(node.iter, ast.Str) or (isinstance(node.iter, ast.Constant) and
|
|
69
|
+
isinstance(node.iter.val, str)):
|
|
70
|
+
# Ast.For which has iter with type of str is converted to python node to avoid instruction injection
|
|
71
|
+
stree.try_append_python_node(node, node)
|
|
72
|
+
return
|
|
62
73
|
iter_code = astunparse.unparse(node.iter)
|
|
63
74
|
if not iter_code.startswith(EVAL_WHITE_LIST):
|
|
64
75
|
logger.warning(
|
|
@@ -72,26 +83,36 @@ class ForParser(Parser):
|
|
|
72
83
|
_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
|
|
73
84
|
f"an error occurred: {str(e)}"
|
|
74
85
|
logger.warning(_info)
|
|
75
|
-
stree.try_append_python_node(node, node)
|
|
86
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
76
87
|
return
|
|
77
88
|
|
|
78
89
|
iter_var_name = iter_code.split(".")[-1]
|
|
79
|
-
|
|
80
|
-
if
|
|
90
|
+
ast_functiondef = node_manager.get_ast_functiondef()
|
|
91
|
+
if not ast_functiondef:
|
|
92
|
+
logger.info(f"ast_functiondef is None in node_manager {node_manager.get_manager_name()} "
|
|
93
|
+
"when parsing 'for' statement.")
|
|
94
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
95
|
+
return
|
|
96
|
+
index = ast_functiondef.body.index(node) + 1
|
|
97
|
+
if isinstance(iter_obj, (list, nn.CellList)):
|
|
81
98
|
for obj in iter_obj:
|
|
82
99
|
if not isinstance(obj, nn.Cell):
|
|
83
|
-
stree.try_append_python_node(node, node)
|
|
100
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
84
101
|
return
|
|
85
102
|
for i, obj in enumerate(iter_obj):
|
|
86
103
|
ForParser.modify_init_ast(stree, i, obj, iter_var_name)
|
|
87
104
|
for body in node.body:
|
|
88
|
-
new_func_name = f"self.{iter_var_name.strip()}
|
|
105
|
+
new_func_name = f"self.{iter_var_name.strip()}{str(i)}".strip()
|
|
89
106
|
new_node = ForParser.modify_construct_ast(stree, body, targets, new_func_name)
|
|
90
|
-
|
|
107
|
+
ast_functiondef.body.insert(index, new_node)
|
|
91
108
|
index += 1
|
|
109
|
+
# Expand "for" statement and replace the body with Pass
|
|
110
|
+
for body in node.body[:]:
|
|
111
|
+
node.body.remove(body)
|
|
112
|
+
node.body.append(ast.Pass())
|
|
113
|
+
|
|
92
114
|
if stree.get_ori_cls_name() == "SequentialCell":
|
|
93
115
|
stree.on_change(Event.CodeChangeEvent)
|
|
94
|
-
stree.get_ast_root().body.remove(node)
|
|
95
116
|
return
|
|
96
117
|
if isinstance(iter_obj, range):
|
|
97
118
|
logger.warning("For MindSpore Rewrite, range not support.")
|
|
@@ -101,7 +122,7 @@ class ForParser(Parser):
|
|
|
101
122
|
logger.warning("For MindSpore Rewrite, enumerate not support.")
|
|
102
123
|
else:
|
|
103
124
|
logger.warning(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
|
|
104
|
-
stree.try_append_python_node(node, node)
|
|
125
|
+
stree.try_append_python_node(node, node, node_manager)
|
|
105
126
|
return
|
|
106
127
|
|
|
107
128
|
g_for_parser = reg_parser(ForParser())
|