mindspore 2.1.0__cp39-none-any.whl → 2.2.10__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/rewrite/namer.py
CHANGED
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Unique name producer for target, name of node, class name, etc."""
|
|
16
16
|
|
|
17
|
-
from typing import Union
|
|
17
|
+
from typing import Union, Tuple
|
|
18
18
|
|
|
19
|
-
from .node import Node
|
|
19
|
+
from .node.node import Node
|
|
20
20
|
from .api.node_type import NodeType
|
|
21
21
|
|
|
22
22
|
|
|
@@ -33,7 +33,7 @@ class Namer:
|
|
|
33
33
|
self._names: {str: int} = {}
|
|
34
34
|
|
|
35
35
|
@staticmethod
|
|
36
|
-
def _real_name(name: str) -> str:
|
|
36
|
+
def _real_name(name: str) -> Tuple[str, int]:
|
|
37
37
|
"""
|
|
38
38
|
Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3".
|
|
39
39
|
If not find real name before find unique name, unique name may be not unique. For example:
|
|
@@ -47,21 +47,21 @@ class Namer:
|
|
|
47
47
|
name (str): Origin name which may have digit prefix.
|
|
48
48
|
|
|
49
49
|
Returns:
|
|
50
|
-
A string represents real-name.
|
|
50
|
+
A string represents real-name and a int represents suffix.
|
|
51
51
|
"""
|
|
52
52
|
if name == '_':
|
|
53
|
-
return name
|
|
53
|
+
return name, None
|
|
54
54
|
pos = name.rfind("_")
|
|
55
|
-
if pos == -1:
|
|
56
|
-
return name
|
|
55
|
+
if pos == -1 or pos == len(name) - 1:
|
|
56
|
+
return name, None
|
|
57
57
|
digit = True
|
|
58
58
|
for i in range(pos + 1, len(name)):
|
|
59
59
|
if not name[i].isdigit():
|
|
60
60
|
digit = False
|
|
61
61
|
break
|
|
62
62
|
if digit:
|
|
63
|
-
return
|
|
64
|
-
return name
|
|
63
|
+
return name[:pos], int(name[pos + 1:])
|
|
64
|
+
return name, None
|
|
65
65
|
|
|
66
66
|
def get_name(self, origin_name: str) -> str:
|
|
67
67
|
"""
|
|
@@ -75,15 +75,28 @@ class Namer:
|
|
|
75
75
|
"""
|
|
76
76
|
if origin_name == '_':
|
|
77
77
|
return origin_name
|
|
78
|
-
|
|
79
|
-
|
|
78
|
+
real_name, suffix_idx = Namer._real_name(origin_name)
|
|
79
|
+
name = origin_name
|
|
80
|
+
number = self._names.get(name)
|
|
80
81
|
if number is None:
|
|
81
|
-
self._names[
|
|
82
|
-
|
|
82
|
+
self._names[name] = 1
|
|
83
|
+
if not suffix_idx:
|
|
84
|
+
# When _names is {x:2} and origin_name is y,
|
|
85
|
+
# origin_name is not in _names and can be returned.
|
|
86
|
+
return name
|
|
87
|
+
if suffix_idx and not self._names.get(real_name, -1) >= suffix_idx:
|
|
88
|
+
# When _names is {x:2} and origin_name is x_3,
|
|
89
|
+
# return x_3 and update _names to {x:2, x_3:1}
|
|
90
|
+
return name
|
|
91
|
+
# When _names is {x:2} and origin_name is x_1,
|
|
92
|
+
# set new_name to x_1_1 by set number to 1, and continue to update name.
|
|
93
|
+
number = 1
|
|
83
94
|
while True:
|
|
84
|
-
new_name = f"{
|
|
95
|
+
new_name = f"{name}_{number}"
|
|
85
96
|
number += 1
|
|
86
|
-
self._names[
|
|
97
|
+
self._names[name] = number
|
|
98
|
+
# When _names is {x:2, x_3:1}, origin_name is x and number is update to 3,
|
|
99
|
+
# new_name x_3 is conflict with key x_3, so this new_name need to be skipped.
|
|
87
100
|
if new_name in self._names.keys():
|
|
88
101
|
continue
|
|
89
102
|
return new_name
|
|
@@ -141,16 +154,12 @@ class NodeNamer(Namer):
|
|
|
141
154
|
if origin_name is None or not origin_name:
|
|
142
155
|
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.CallFunction,
|
|
143
156
|
NodeType.Tree):
|
|
144
|
-
|
|
145
|
-
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
|
|
146
|
-
targets = node_or_name.get_targets()
|
|
147
|
-
# return node and head node will not call this method
|
|
148
|
-
if not targets:
|
|
149
|
-
raise RuntimeError("node should has at lease one target except return-node and head-node: ",
|
|
150
|
-
node_or_name)
|
|
151
|
-
origin_name = str(targets[0].value)
|
|
157
|
+
origin_name = type(node_or_name.get_instance()).__name__
|
|
152
158
|
elif node_or_name.get_node_type() == NodeType.Python:
|
|
153
|
-
|
|
159
|
+
if node_or_name.get_instance():
|
|
160
|
+
origin_name = type(node_or_name.get_instance()).__name__
|
|
161
|
+
else:
|
|
162
|
+
origin_name = "python_node"
|
|
154
163
|
elif node_or_name.get_node_type() == NodeType.Input:
|
|
155
164
|
origin_name = "parameter"
|
|
156
165
|
elif node_or_name.get_node_type() == NodeType.Output:
|
mindspore/rewrite/namespace.py
CHANGED
|
@@ -21,12 +21,21 @@ _ms_nn_ns = CellNamespace('mindspore.nn')
|
|
|
21
21
|
_ms_ops_ns = CellNamespace('mindspore.ops.operations')
|
|
22
22
|
_ms_functional_ns = CellNamespace('mindspore.ops.functional')
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
24
|
+
# Elements in _subtree_black_list will not be converted to symbol tree.
|
|
25
|
+
# Only str and types are stored in _subtree_black_list.
|
|
26
|
+
_subtree_black_list = ["QuantizeWrapperCell",]
|
|
27
|
+
|
|
28
|
+
def is_subtree(cls_inst):
|
|
29
|
+
"""Determine whether 'cls_inst' is a subtree."""
|
|
30
|
+
cls_name = type(cls_inst).__name__
|
|
31
|
+
black_list_types = tuple([elem for elem in _subtree_black_list if not isinstance(elem, str)])
|
|
32
|
+
if cls_name in _subtree_black_list or isinstance(cls_inst, black_list_types):
|
|
33
|
+
return False
|
|
34
|
+
if cls_name in _ms_common_ns and isinstance(cls_inst, _ms_common_ns[cls_name]):
|
|
35
|
+
return False
|
|
36
|
+
if cls_name in _ms_nn_ns and isinstance(cls_inst, _ms_nn_ns[cls_name]):
|
|
28
37
|
return False
|
|
29
|
-
if cls_name in
|
|
38
|
+
if cls_name in _ms_ops_ns and isinstance(cls_inst, _ms_ops_ns[cls_name]):
|
|
30
39
|
return False
|
|
31
40
|
|
|
32
41
|
return True
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,11 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
from .
|
|
21
|
-
from .
|
|
22
|
-
from .
|
|
15
|
+
"""
|
|
16
|
+
SymbolTree node
|
|
17
|
+
"""
|
|
18
|
+
from mindspore.rewrite.node.node import Node, TreeNode
|
|
19
|
+
from mindspore.rewrite.node.node_manager import NodeManager
|
|
20
|
+
from mindspore.rewrite.node.call_function import CallFunction
|
|
21
|
+
from mindspore.rewrite.node.cell_container import CellContainer
|
|
22
|
+
from mindspore.rewrite.node.control_flow import ControlFlow
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""CallFunction Node."""
|
|
16
|
+
import ast
|
|
17
|
+
from .node import Node
|
|
18
|
+
from .node_manager import NodeManager
|
|
19
|
+
from ..api.scoped_value import ScopedValue
|
|
20
|
+
from ..api.node_type import NodeType
|
|
21
|
+
from ..ast_helpers import AstModifier
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CallFunction(Node, NodeManager):
|
|
25
|
+
"""CallFunction is used for class internal function."""
|
|
26
|
+
def __init__(self, targets: [ScopedValue], func_name: ScopedValue, args: [ScopedValue],
|
|
27
|
+
kwargs: {str: ScopedValue}, node_name: str, ast_node: ast.AST, ast_functiondef: ast.FunctionDef,
|
|
28
|
+
stree, instance):
|
|
29
|
+
"""
|
|
30
|
+
Constructor of CallFunction.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
34
|
+
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
35
|
+
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
36
|
+
func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
37
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
38
|
+
SymbolTree. Name of node also used as field name in network class.
|
|
39
|
+
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
40
|
+
ast_functiondef (ast.FunctionDef): An instance of ast.FunctionDef represents corresponding function
|
|
41
|
+
definition in ast.
|
|
42
|
+
stree (SymbolTree): Symbol tree used to get node_namer.
|
|
43
|
+
instance: Object in network corresponding to this node.
|
|
44
|
+
"""
|
|
45
|
+
if isinstance(func_name, str):
|
|
46
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
47
|
+
Node.__init__(self, NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, instance)
|
|
48
|
+
NodeManager.__init__(self, stree.get_node_namer())
|
|
49
|
+
NodeManager.set_ast_functiondef(self, ast_functiondef)
|
|
50
|
+
NodeManager.set_manager_name(self, func_name.value)
|
|
51
|
+
|
|
52
|
+
def erase_node(self, node):
|
|
53
|
+
"""Erase node from CallFunction."""
|
|
54
|
+
NodeManager.erase_node(self, node)
|
|
55
|
+
# erase asts
|
|
56
|
+
ret = AstModifier.erase_ast_from_function(self.get_ast_functiondef(), node.get_ast())
|
|
57
|
+
if not ret:
|
|
58
|
+
raise ValueError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
59
|
+
|
|
60
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
|
|
61
|
+
"""
|
|
62
|
+
Insert a node before or after base_node.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
new_node (Node): Node to be inserted.
|
|
66
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
67
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
68
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
69
|
+
"""
|
|
70
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
71
|
+
if insert_to_ast:
|
|
72
|
+
stree = self.get_belong_symbol_tree()
|
|
73
|
+
stree.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
|
|
74
|
+
|
|
75
|
+
def set_belong_symbol_tree(self, symbol_tree):
|
|
76
|
+
"""Set the symbol tree to which node belongs."""
|
|
77
|
+
self._belong_tree = symbol_tree
|
|
78
|
+
for node in self.nodes():
|
|
79
|
+
node.set_belong_symbol_tree(symbol_tree)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""CellContainer Node."""
|
|
16
|
+
import ast
|
|
17
|
+
from mindspore import log as logger
|
|
18
|
+
from .node import Node
|
|
19
|
+
from .node_manager import NodeManager
|
|
20
|
+
from ..api.scoped_value import ScopedValue
|
|
21
|
+
from ..api.node_type import NodeType
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CellContainer(Node, NodeManager):
|
|
25
|
+
"""CellContainer is used for nn.SequencialCell."""
|
|
26
|
+
|
|
27
|
+
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func_name: ScopedValue,
|
|
28
|
+
args: [ScopedValue], kwargs: {str: ScopedValue}, node_name: str, stree, instance):
|
|
29
|
+
"""Constructor of CellContainer.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
33
|
+
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
34
|
+
func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
35
|
+
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
36
|
+
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
37
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
38
|
+
SymbolTree. Name of node also used as field name in network class.
|
|
39
|
+
stree (SymbolTree): Symbol tree used to get node_namer.
|
|
40
|
+
instance: Object in network corresponding to this node.
|
|
41
|
+
"""
|
|
42
|
+
if isinstance(func_name, str):
|
|
43
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
44
|
+
Node.__init__(self, NodeType.CellContainer, ast_node, targets, func_name, args, kwargs, node_name, instance)
|
|
45
|
+
NodeManager.__init__(self, stree.get_node_namer())
|
|
46
|
+
NodeManager.set_manager_name(self, func_name.value)
|
|
47
|
+
|
|
48
|
+
def append(self, node, insert_to_ast: bool = True):
|
|
49
|
+
""" Append new node to node list. """
|
|
50
|
+
self.append_node(node, insert_to_ast)
|
|
51
|
+
|
|
52
|
+
def append_node(self, node, insert_to_ast: bool = True):
|
|
53
|
+
""" Append new node to node list. """
|
|
54
|
+
self.insert_node(node, self.get_tail(), False, insert_to_ast)
|
|
55
|
+
|
|
56
|
+
def erase(self, node):
|
|
57
|
+
"""Erase node from container."""
|
|
58
|
+
self.erase_node(node)
|
|
59
|
+
|
|
60
|
+
def erase_node(self, node):
|
|
61
|
+
"""Erase node from container."""
|
|
62
|
+
# add code `del self.container_name[node_index]` into __init__ function
|
|
63
|
+
_, init_ast_functiondef = self._get_stree_and_init_ast()
|
|
64
|
+
if not init_ast_functiondef:
|
|
65
|
+
logger.error(f"Erase node {node.get_name()} failed: get symboltree and __init__ ast failed.")
|
|
66
|
+
return
|
|
67
|
+
node_idx = self.nodes().index(node)
|
|
68
|
+
erase_code = f"del {self.get_func_name()}[{node_idx}]"
|
|
69
|
+
erase_ast = ast.parse(erase_code).body[0]
|
|
70
|
+
init_ast_functiondef.body.append(erase_ast)
|
|
71
|
+
# earse node in NodeManager
|
|
72
|
+
NodeManager.erase_node(self, node)
|
|
73
|
+
|
|
74
|
+
def insert(self, index, node, insert_to_ast: bool = True):
|
|
75
|
+
"""Insert node into container according index"""
|
|
76
|
+
node_index = index + len(self._inputs)
|
|
77
|
+
if node_index >= self.node_count:
|
|
78
|
+
raise IndexError("In MindSpore Rewrite CellContainer, inserting a node raises index error! "
|
|
79
|
+
f"node_index: {node_index} >= node_num: {self.node_count}")
|
|
80
|
+
self.insert_node(node, self.nodes()[node_index], False, insert_to_ast)
|
|
81
|
+
|
|
82
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
|
|
83
|
+
"""
|
|
84
|
+
Insert a node before or after base_node.
|
|
85
|
+
|
|
86
|
+
The instance is modified here. The scenario needs to be optimized.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
new_node (Node): Node to be inserted.
|
|
90
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
91
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
92
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
93
|
+
"""
|
|
94
|
+
# Insert node to NodeManager firstly to update node_name, which is used during insert ast.
|
|
95
|
+
# tail_node may be changed after insert node into node_manager, so we record tail node here.
|
|
96
|
+
tail_node = self.get_tail()
|
|
97
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
98
|
+
new_node.set_func_name(ScopedValue.create_naming_value(new_node.get_name()))
|
|
99
|
+
new_node.update_ast_node()
|
|
100
|
+
# add insert/append code into __init__ function
|
|
101
|
+
if insert_to_ast:
|
|
102
|
+
stree, init_ast_functiondef = self._get_stree_and_init_ast()
|
|
103
|
+
if not init_ast_functiondef:
|
|
104
|
+
logger.error(f"Insert new_node {new_node.get_name()} failed: get symboltree and __init__ ast failed.")
|
|
105
|
+
return
|
|
106
|
+
setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
|
|
107
|
+
node_idx = self.nodes().index(base_node)
|
|
108
|
+
if before_node:
|
|
109
|
+
insert_code = f"{self.get_func_name()}._insert({node_idx}, self.{new_node.get_name()})"
|
|
110
|
+
else:
|
|
111
|
+
if base_node == tail_node:
|
|
112
|
+
insert_code = f"{self.get_func_name()}.append(self.{new_node.get_name()})"
|
|
113
|
+
else:
|
|
114
|
+
insert_code = f"{self.get_func_name()}._insert({node_idx + 1}, self.{new_node.get_name()})"
|
|
115
|
+
insert_ast = ast.parse(insert_code).body[0]
|
|
116
|
+
init_ast_functiondef.body.append(insert_ast)
|
|
117
|
+
|
|
118
|
+
def set_belong_symbol_tree(self, symbol_tree):
|
|
119
|
+
"""Set the symbol tree to which node belongs."""
|
|
120
|
+
self._belong_tree = symbol_tree
|
|
121
|
+
for node in self.nodes():
|
|
122
|
+
node.set_belong_symbol_tree(symbol_tree)
|
|
123
|
+
|
|
124
|
+
def _get_stree_and_init_ast(self):
|
|
125
|
+
"""Get symbol tree and ast of __init__ function from container."""
|
|
126
|
+
# add codes `del self.container_name[node_index]`` into __init__ function
|
|
127
|
+
stree = self.get_belong_symbol_tree()
|
|
128
|
+
if stree is None:
|
|
129
|
+
logger.error(f"Get symboltree of CellContainer {self.get_name()} failed.")
|
|
130
|
+
return None, None
|
|
131
|
+
init_ast_functiondef = stree.get_init_func_ast()
|
|
132
|
+
if init_ast_functiondef is None:
|
|
133
|
+
logger.error(f"Get ast of __init__ function from class {stree.get_opt_cls_name()} failed.")
|
|
134
|
+
return None, None
|
|
135
|
+
return stree, init_ast_functiondef
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""ControlFlow Node."""
|
|
16
|
+
from typing import List
|
|
17
|
+
import ast
|
|
18
|
+
from .node import Node, TreeNode
|
|
19
|
+
from .node_manager import NodeManager
|
|
20
|
+
from ..api.scoped_value import ScopedValue
|
|
21
|
+
from ..api.node_type import NodeType
|
|
22
|
+
from ..ast_helpers import AstModifier
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ControlFlow(Node, NodeManager):
|
|
26
|
+
"""ControlFlow node is used for statements like loops and `if` ."""
|
|
27
|
+
def __init__(self, node_name: str, ast_body: List[ast.AST], stree):
|
|
28
|
+
"""
|
|
29
|
+
Constructor of ControlFlow.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
33
|
+
SymbolTree. Name of node also used as field name in network class.
|
|
34
|
+
ast_node (ast.AST): An instance of ast.AST represents control flow statements, can be one of ast.If,
|
|
35
|
+
ast.Ifexp, ast.For, ast.While.
|
|
36
|
+
is_orelse (bool): Whether process else branch of node.
|
|
37
|
+
stree (SymbolTree): Symbol tree used to get node_namer.
|
|
38
|
+
"""
|
|
39
|
+
Node.__init__(self, NodeType.ControlFlow, ast_body, None, node_name, [], [], node_name, None)
|
|
40
|
+
NodeManager.__init__(self, stree.get_node_namer())
|
|
41
|
+
NodeManager.set_manager_name(self, node_name)
|
|
42
|
+
self.ast_body = ast_body
|
|
43
|
+
|
|
44
|
+
def erase_node(self, node):
|
|
45
|
+
"""Erase node from container."""
|
|
46
|
+
NodeManager.erase_node(self, node)
|
|
47
|
+
# erase node's ast
|
|
48
|
+
ret = AstModifier.erase_ast_from_bodies(self.ast_body, node.get_ast())
|
|
49
|
+
if not ret:
|
|
50
|
+
raise ValueError(f"Erase node failed, node {node.get_name()} is not in ControlFlow ast tree.")
|
|
51
|
+
|
|
52
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
|
|
53
|
+
"""
|
|
54
|
+
Insert a node before or after base_node.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
new_node (Node): Node to be inserted.
|
|
58
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
59
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
60
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
61
|
+
"""
|
|
62
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
63
|
+
if insert_to_ast:
|
|
64
|
+
ast_assign = new_node.get_ast()
|
|
65
|
+
if ast_assign is None:
|
|
66
|
+
func_name = new_node.get_belong_symbol_tree().unique_func_name(new_node.get_name())
|
|
67
|
+
new_node.set_func_name(ScopedValue.create_naming_value(func_name, "self"))
|
|
68
|
+
ast_assign = new_node.update_ast_node()
|
|
69
|
+
# Save instance into _origin_network.
|
|
70
|
+
stree = self.get_belong_symbol_tree()
|
|
71
|
+
setattr(stree.get_origin_network(), new_node.get_name(), new_node.get_instance())
|
|
72
|
+
# Insert ast_assign to __init__ function
|
|
73
|
+
if isinstance(new_node, TreeNode):
|
|
74
|
+
init_code = f"self.{new_node.get_name()} = " \
|
|
75
|
+
f"{new_node.symbol_tree.get_opt_cls_name()}(obj.{new_node.get_name()})"
|
|
76
|
+
else:
|
|
77
|
+
init_code = f"self.{new_node.get_name()} = obj.{new_node.get_name()}"
|
|
78
|
+
init_ast = ast.parse(init_code).body[0]
|
|
79
|
+
AstModifier.insert_assign_ast_to_function(stree.get_init_func_ast(), init_ast)
|
|
80
|
+
# Insert ast_assign to bodies
|
|
81
|
+
ast_base_node = base_node.get_ast() if base_node else None
|
|
82
|
+
AstModifier.insert_assign_ast_to_bodies(self.ast_body, ast_assign, ast_base_node, before_node)
|
|
83
|
+
|
|
84
|
+
def set_belong_symbol_tree(self, symbol_tree):
|
|
85
|
+
"""Set the symbol tree to which node belongs."""
|
|
86
|
+
self._belong_tree = symbol_tree
|
|
87
|
+
for node in self.nodes():
|
|
88
|
+
node.set_belong_symbol_tree(symbol_tree)
|