mindspore 2.1.0__cp39-none-any.whl → 2.2.11__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -13,8 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: SymbolTree."""
|
|
16
|
-
from typing import Optional, Union
|
|
17
|
-
from types import FunctionType
|
|
16
|
+
from typing import Optional, Union, List
|
|
18
17
|
import mindspore as ms
|
|
19
18
|
|
|
20
19
|
from mindspore.nn import Cell
|
|
@@ -53,7 +52,61 @@ class SymbolTree:
|
|
|
53
52
|
|
|
54
53
|
This interface parses the `network` instance, expands each source
|
|
55
54
|
code statement of the forward computation process, and parses it into nodes,
|
|
56
|
-
which is stored in the SymbolTree.
|
|
55
|
+
which is stored in the SymbolTree. The specific process is as follows:
|
|
56
|
+
|
|
57
|
+
1. Obtain the source code of the network instance.
|
|
58
|
+
2. Perform AST parsing on the network and obtain the AST nodes (abstract syntax trees) of each
|
|
59
|
+
statement in the network.
|
|
60
|
+
3. Expand complex statements in the network forward evaluation process into multiple simple statements.
|
|
61
|
+
4. Create a SymbolTree object. Each SymbolTree corresponds to one network instance.
|
|
62
|
+
5. Use the rewrite node to store each statement of the network forward computation process. The node records
|
|
63
|
+
the input, output, and other information of the statement.
|
|
64
|
+
6. Save the rewrite node to the SymbolTree, and update and maintain the topological connection between
|
|
65
|
+
the nodes.
|
|
66
|
+
7. Return the SymbolTree object corresponding to the network instance.
|
|
67
|
+
|
|
68
|
+
If a user-defined network of type :class:`mindspore.nn.Cell` is called in the forward computation process
|
|
69
|
+
of the network, rewrite will generate a node of type `NodeType.Tree` for the corresponding statement. This
|
|
70
|
+
type of node stores a new SymbolTree, which parses and maintains the node information of the user-defined
|
|
71
|
+
network.
|
|
72
|
+
|
|
73
|
+
If the following types of statements are called in the forward computation process of the network, rewrite
|
|
74
|
+
will parse the internal statements in the statement and generate corresponding nodes:
|
|
75
|
+
|
|
76
|
+
- :class:`mindspore.nn.SequentialCell`
|
|
77
|
+
- Functions within classes
|
|
78
|
+
- Control flow statements, such as `if` statements
|
|
79
|
+
|
|
80
|
+
Note:
|
|
81
|
+
Because the specific execution branch of a control flow is still unknown during the rewrite operation
|
|
82
|
+
of the network, no topology information will be established between the nodes inside the control flow
|
|
83
|
+
and the nodes outside.
|
|
84
|
+
Users cannot obtain nodes inside the control flow when they acquire nodes outside the control flow using
|
|
85
|
+
interfaces like :func:`mindspore.rewrite.Node.get_inputs` and :func:`mindspore.rewrite.Node.get_users` .
|
|
86
|
+
Users also cannot obtain nodes outside the control flow, if they use these interfaces inside the control
|
|
87
|
+
flow.
|
|
88
|
+
Therefore, when users modify the network, they need to manually handle the node information inside and
|
|
89
|
+
outside the control flow.
|
|
90
|
+
|
|
91
|
+
The current rewrite module has the following syntax limitations:
|
|
92
|
+
|
|
93
|
+
- Only networks of type :class:`mindspore.nn.Cell` are supported as input to the rewrite module.
|
|
94
|
+
- Parsing assignment statements with multiple output values is not currently supported.
|
|
95
|
+
- Parsing loop statements is not currently supported.
|
|
96
|
+
- Parsing decorator syntax is not currently supported.
|
|
97
|
+
- Parsing class variable syntax is not currently supported. If a class variable uses external data,
|
|
98
|
+
the network after rewrite may be missing data.
|
|
99
|
+
- Parsing local classes and embedded classes is not currently supported, that is, the definition
|
|
100
|
+
of classes needs to be placed at the outermost layer.
|
|
101
|
+
- Parsing closure syntax is not currently supported, that is, the definition of out-of-class
|
|
102
|
+
functions needs to be placed at the outermost layer.
|
|
103
|
+
- Parsing lambda expression syntax is not currently supported.
|
|
104
|
+
|
|
105
|
+
For statements that do not support parsing, rewrite will generate nodes of type `NodeType.Python`
|
|
106
|
+
for corresponding statements to ensure that the network after rewrite can run normally.
|
|
107
|
+
The `Python` node does not support modifying the input and output of statements, and there may be
|
|
108
|
+
a problem between variable names and those generated by the rewrite. In this case, users need to
|
|
109
|
+
adjust the variable names manually.
|
|
57
110
|
|
|
58
111
|
Args:
|
|
59
112
|
network (Cell): `network` used to create SymbolTree.
|
|
@@ -67,7 +120,7 @@ class SymbolTree:
|
|
|
67
120
|
Examples:
|
|
68
121
|
>>> from mindspore.rewrite import SymbolTree
|
|
69
122
|
>>> # Define the network structure of LeNet5. Refer to
|
|
70
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
123
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
71
124
|
>>> net = LeNet5()
|
|
72
125
|
>>> stree = SymbolTree.create(net)
|
|
73
126
|
>>> print(type(stree))
|
|
@@ -90,42 +143,37 @@ class SymbolTree:
|
|
|
90
143
|
if v not in MsDtypes and not isinstance(v, ParamTypes):
|
|
91
144
|
raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
|
|
92
145
|
|
|
93
|
-
def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
|
|
94
|
-
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
|
|
95
|
-
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
|
|
96
|
-
args_ = list(args)
|
|
97
|
-
SymbolTree._check_args_type(args_)
|
|
98
|
-
for i, arg in enumerate(args_):
|
|
99
|
-
if isinstance(arg, Node):
|
|
100
|
-
args_[i] = arg.get_handler()
|
|
101
|
-
SymbolTree._check_kwargs_type(kwargs)
|
|
102
|
-
for key, value in kwargs.items():
|
|
103
|
-
if isinstance(value, Node):
|
|
104
|
-
kwargs[key] = value.get_handler()
|
|
105
|
-
return Node(self._symbol_tree._create_call_function(func, targets, args_, kwargs)) # pylint: disable=W0212
|
|
106
|
-
|
|
107
146
|
def get_handler(self) -> SymbolTreeImpl:
|
|
108
147
|
return self._symbol_tree
|
|
109
148
|
|
|
110
|
-
def nodes(self):
|
|
149
|
+
def nodes(self, all_nodes: bool = False):
|
|
111
150
|
"""
|
|
112
151
|
Get the generator of the node in the current SymbolTree, which is used to iterate
|
|
113
152
|
through the nodes in SymbolTree.
|
|
114
153
|
|
|
154
|
+
Args:
|
|
155
|
+
all_nodes (bool): Get all nodes including nodes in CallFunction node, CellContainer node
|
|
156
|
+
and sub symbol tree. Default: ``False`` .
|
|
157
|
+
|
|
115
158
|
Returns:
|
|
116
|
-
A generator for
|
|
159
|
+
A generator for nodes in SymbolTree.
|
|
160
|
+
|
|
161
|
+
Raises:
|
|
162
|
+
TypeError: If `all_nodes` is not bool.
|
|
117
163
|
|
|
118
164
|
Examples:
|
|
119
165
|
>>> from mindspore.rewrite import SymbolTree
|
|
120
166
|
>>> # Define the network structure of LeNet5. Refer to
|
|
121
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
167
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
122
168
|
>>> net = LeNet5()
|
|
123
169
|
>>> stree = SymbolTree.create(net)
|
|
124
170
|
>>> print([node.get_name() for node in stree.nodes()])
|
|
125
171
|
['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
|
|
126
172
|
'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return']
|
|
127
173
|
"""
|
|
128
|
-
|
|
174
|
+
Validator.check_value_type("all_nodes", all_nodes, [bool], "nodes")
|
|
175
|
+
nodes = self._symbol_tree.all_nodes() if all_nodes else self._symbol_tree.nodes()
|
|
176
|
+
for node in nodes:
|
|
129
177
|
yield Node(node)
|
|
130
178
|
|
|
131
179
|
def get_node(self, node_name: str) -> Optional[Node]:
|
|
@@ -141,7 +189,7 @@ class SymbolTree:
|
|
|
141
189
|
Examples:
|
|
142
190
|
>>> from mindspore.rewrite import SymbolTree
|
|
143
191
|
>>> # Define the network structure of LeNet5. Refer to
|
|
144
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
192
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
145
193
|
>>> net = LeNet5()
|
|
146
194
|
>>> stree = SymbolTree.create(net)
|
|
147
195
|
>>> node = stree.get_node('conv1')
|
|
@@ -149,12 +197,12 @@ class SymbolTree:
|
|
|
149
197
|
conv1
|
|
150
198
|
"""
|
|
151
199
|
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
|
|
152
|
-
node_impl = self._symbol_tree.
|
|
200
|
+
node_impl = self._symbol_tree.get_node_from_name(node_name)
|
|
153
201
|
if node_impl is None:
|
|
154
202
|
return None
|
|
155
203
|
return Node(node_impl)
|
|
156
204
|
|
|
157
|
-
def get_inputs(self) -> [Node]:
|
|
205
|
+
def get_inputs(self) -> List[Node]:
|
|
158
206
|
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
|
|
159
207
|
|
|
160
208
|
def before(self, node: Union[Node, str]):
|
|
@@ -174,15 +222,17 @@ class SymbolTree:
|
|
|
174
222
|
Examples:
|
|
175
223
|
>>> from mindspore.rewrite import SymbolTree
|
|
176
224
|
>>> # Define the network structure of LeNet5. Refer to
|
|
177
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
225
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
178
226
|
>>> net = LeNet5()
|
|
179
227
|
>>> stree = SymbolTree.create(net)
|
|
180
228
|
>>> for node in stree.nodes():
|
|
181
229
|
... if node.get_name() == "conv1":
|
|
182
230
|
... position = stree.before(node)
|
|
183
231
|
"""
|
|
184
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
185
|
-
|
|
232
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
233
|
+
if isinstance(node, Node):
|
|
234
|
+
node = node.get_handler()
|
|
235
|
+
return self._symbol_tree.before(node)
|
|
186
236
|
|
|
187
237
|
def after(self, node: Union[Node, str]):
|
|
188
238
|
"""
|
|
@@ -201,15 +251,17 @@ class SymbolTree:
|
|
|
201
251
|
Examples:
|
|
202
252
|
>>> from mindspore.rewrite import SymbolTree
|
|
203
253
|
>>> # Define the network structure of LeNet5. Refer to
|
|
204
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
254
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
205
255
|
>>> net = LeNet5()
|
|
206
256
|
>>> stree = SymbolTree.create(net)
|
|
207
257
|
>>> for node in stree.nodes():
|
|
208
258
|
... if node.get_name() == "conv1":
|
|
209
259
|
... position = stree.after(node)
|
|
210
260
|
"""
|
|
211
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
212
|
-
|
|
261
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
262
|
+
if isinstance(node, Node):
|
|
263
|
+
node = node.get_handler()
|
|
264
|
+
return self._symbol_tree.after(node)
|
|
213
265
|
|
|
214
266
|
def insert(self, position, node: Node) -> Node:
|
|
215
267
|
"""
|
|
@@ -233,7 +285,7 @@ class SymbolTree:
|
|
|
233
285
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
234
286
|
>>> import mindspore.nn as nn
|
|
235
287
|
>>> # Define the network structure of LeNet5. Refer to
|
|
236
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
288
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
237
289
|
>>> net = LeNet5()
|
|
238
290
|
>>> stree = SymbolTree.create(net)
|
|
239
291
|
>>> node = stree.get_node("conv1")
|
|
@@ -244,7 +296,7 @@ class SymbolTree:
|
|
|
244
296
|
"""
|
|
245
297
|
Validator.check_value_type("position", position, [Position], "SymbolTree")
|
|
246
298
|
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
247
|
-
return Node(self._symbol_tree.insert_node(
|
|
299
|
+
return Node(self._symbol_tree.insert_node(node.get_handler(), position.node, position.before_node))
|
|
248
300
|
|
|
249
301
|
def erase(self, node: Union[Node, str]) -> Optional[Node]:
|
|
250
302
|
"""
|
|
@@ -262,16 +314,18 @@ class SymbolTree:
|
|
|
262
314
|
Examples:
|
|
263
315
|
>>> from mindspore.rewrite import SymbolTree
|
|
264
316
|
>>> # Define the network structure of LeNet5. Refer to
|
|
265
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
317
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
266
318
|
>>> net = LeNet5()
|
|
267
319
|
>>> stree = SymbolTree.create(net)
|
|
268
320
|
>>> node = stree.get_node("conv1")
|
|
269
321
|
>>> stree.erase(node)
|
|
270
322
|
"""
|
|
271
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
272
|
-
|
|
323
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
324
|
+
if isinstance(node, Node):
|
|
325
|
+
node = node.get_handler()
|
|
326
|
+
return Node(self._symbol_tree.erase_node(node))
|
|
273
327
|
|
|
274
|
-
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
328
|
+
def replace(self, old_node: Node, new_nodes: List[Node]) -> Node:
|
|
275
329
|
"""
|
|
276
330
|
Replace the `old_node` with nodes in the `new_nodes` list.
|
|
277
331
|
|
|
@@ -285,7 +339,7 @@ class SymbolTree:
|
|
|
285
339
|
|
|
286
340
|
Args:
|
|
287
341
|
old_node (Node): Node to be replaced.
|
|
288
|
-
new_nodes (
|
|
342
|
+
new_nodes (List[Node]): Nodes of the node_tree to replace in.
|
|
289
343
|
|
|
290
344
|
Returns:
|
|
291
345
|
An instance of Node representing the root of the node_tree that has been replaced in.
|
|
@@ -299,7 +353,7 @@ class SymbolTree:
|
|
|
299
353
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
300
354
|
>>> import mindspore.nn as nn
|
|
301
355
|
>>> # Define the network structure of LeNet5. Refer to
|
|
302
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
356
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
303
357
|
>>> net = LeNet5()
|
|
304
358
|
>>> stree = SymbolTree.create(net)
|
|
305
359
|
>>> node = stree.get_node("conv1")
|
|
@@ -320,16 +374,38 @@ class SymbolTree:
|
|
|
320
374
|
def dump(self):
|
|
321
375
|
self._symbol_tree.dump()
|
|
322
376
|
|
|
323
|
-
def print_node_tabulate(self):
|
|
324
|
-
"""
|
|
377
|
+
def print_node_tabulate(self, all_nodes: bool = False):
|
|
378
|
+
r"""
|
|
325
379
|
Print the topology information of nodes in SymbolTree, including node type, node name, node code,
|
|
326
380
|
and node input-output relationship.
|
|
327
|
-
The information is output to the screen using the print interface.
|
|
328
381
|
|
|
329
|
-
|
|
330
|
-
|
|
382
|
+
The information is output to the screen using the print interface, including the following information:
|
|
383
|
+
|
|
384
|
+
- **node type** (str): The type of node, refer to :class:`mindspore.rewrite.NodeType` .
|
|
385
|
+
- **name** (str): The name of node.
|
|
386
|
+
- **codes** (str): The source code statement corresponding to the node.
|
|
387
|
+
- **arg providers** (Dict[int, Tuple[str, int]]): The format is `{[idx, (n, k)]}` , which means the
|
|
388
|
+
`idx` th parameter of the node is provided by the `k` th output of node `n` .
|
|
389
|
+
- **target users** (Dict[int, List[Tuple[str, int]]]): The format is `{[idx, [(n, k)]]}` , which means
|
|
390
|
+
the `idx` th output of the node is used as the `k` th parameter of node `n` .
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
all_nodes (bool): Print information of all nodes, including nodes in CallFunction
|
|
394
|
+
node, CellContainer node and sub symbol tree. Default: ``False`` .
|
|
395
|
+
|
|
396
|
+
Raises:
|
|
397
|
+
TypeError: If `all_nodes` is not bool.
|
|
398
|
+
|
|
399
|
+
Examples:
|
|
400
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
401
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
402
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
403
|
+
>>> net = LeNet5()
|
|
404
|
+
>>> stree = SymbolTree.create(net)
|
|
405
|
+
>>> stree.print_node_tabulate()
|
|
331
406
|
"""
|
|
332
|
-
|
|
407
|
+
Validator.check_value_type("all_nodes", all_nodes, [bool], "print_node_tabulate")
|
|
408
|
+
self._symbol_tree.print_node_tabulate(all_nodes)
|
|
333
409
|
|
|
334
410
|
def get_code(self) -> str:
|
|
335
411
|
"""
|
|
@@ -342,7 +418,7 @@ class SymbolTree:
|
|
|
342
418
|
Examples:
|
|
343
419
|
>>> from mindspore.rewrite import SymbolTree
|
|
344
420
|
>>> # Define the network structure of LeNet5. Refer to
|
|
345
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
421
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
346
422
|
>>> net = LeNet5()
|
|
347
423
|
>>> stree = SymbolTree.create(net)
|
|
348
424
|
>>> codes = stree.get_code()
|
|
@@ -355,13 +431,21 @@ class SymbolTree:
|
|
|
355
431
|
Get the network object generated based on SymbolTree.
|
|
356
432
|
The source code is saved to a file in the 'rewritten_network' folder of the current directory.
|
|
357
433
|
|
|
434
|
+
Note:
|
|
435
|
+
- The modification of network by rewrite module is based on the modification of AST tree of
|
|
436
|
+
original network instance, and the new network instance will obtain attribute information
|
|
437
|
+
from original network instance, so the new network instance and the original network instance
|
|
438
|
+
have data association, and the original network should no longer be used.
|
|
439
|
+
- Due to the data association between the new network and the original network instance, manually creating
|
|
440
|
+
a network instance using the source code file generated by rewrite is not currently supported.
|
|
441
|
+
|
|
358
442
|
Returns:
|
|
359
443
|
A network object generated from SymbolTree.
|
|
360
444
|
|
|
361
445
|
Examples:
|
|
362
446
|
>>> from mindspore.rewrite import SymbolTree
|
|
363
447
|
>>> # Define the network structure of LeNet5. Refer to
|
|
364
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.
|
|
448
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
365
449
|
>>> net = LeNet5()
|
|
366
450
|
>>> stree = SymbolTree.create(net)
|
|
367
451
|
>>> new_net = stree.get_network()
|
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
Define some ast helpers for manipulating python ast.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj
|
|
20
|
+
from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj, \
|
|
21
|
+
AstAssignFinder, AstClassFinder, AstFunctionFinder
|
|
21
22
|
from .ast_replacer import AstReplacer
|
|
22
23
|
from .ast_modifier import AstModifier
|
|
23
24
|
from .ast_creator import ast_args_creator, ast_assign_creator, ast_attributer_creator, ast_call_creator, \
|
|
@@ -225,3 +225,132 @@ class GetPropertyOfObj(ast.NodeVisitor):
|
|
|
225
225
|
self._property = set()
|
|
226
226
|
self.generic_visit(self._context)
|
|
227
227
|
return self._property
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class AstAssignFinder(ast.NodeVisitor):
|
|
231
|
+
"""
|
|
232
|
+
Get the assign definition ast of a specific parameter in a specific scope.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
node (ast.AST): An instance of ast node as check scope.
|
|
236
|
+
"""
|
|
237
|
+
def __init__(self, node: ast.AST):
|
|
238
|
+
self._context = node
|
|
239
|
+
self._scope = ""
|
|
240
|
+
self._value = ""
|
|
241
|
+
self._target = None
|
|
242
|
+
|
|
243
|
+
def visit_Assign(self, node: ast.Assign):
|
|
244
|
+
if self._scope and isinstance(node.targets[0], ast.Attribute):
|
|
245
|
+
if node.targets[0].attr == self._value and isinstance(node.targets[0].value, ast.Name) \
|
|
246
|
+
and node.targets[0].value.id == self._scope:
|
|
247
|
+
self._target = node
|
|
248
|
+
elif not self._scope and isinstance(node.targets[0], ast.Name):
|
|
249
|
+
if node.targets[0].id == self._value:
|
|
250
|
+
self._target = node
|
|
251
|
+
|
|
252
|
+
def get_ast(self, value: str, scope: str = "") -> bool:
|
|
253
|
+
"""
|
|
254
|
+
Get the assign ast of a specific parameter in a specific ast.
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
value (str): A string indicates assign target value.
|
|
258
|
+
scope (str): A string indicates assign target scope.
|
|
259
|
+
|
|
260
|
+
Returns:
|
|
261
|
+
An assign ast with the same target name as `scope.value` .
|
|
262
|
+
"""
|
|
263
|
+
self._scope = scope
|
|
264
|
+
self._value = value
|
|
265
|
+
self.generic_visit(self._context)
|
|
266
|
+
return self._target
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class AstClassFinder(ast.NodeVisitor):
|
|
270
|
+
"""
|
|
271
|
+
Find all ast class nodes with a specific name in a specific scope.
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
node (ast.AST): An instance of ast node as search scope.
|
|
275
|
+
"""
|
|
276
|
+
|
|
277
|
+
def __init__(self, node: ast.AST):
|
|
278
|
+
self._scope: ast.AST = node
|
|
279
|
+
self._target: str = ""
|
|
280
|
+
self._results: [ast.ClassDef] = []
|
|
281
|
+
|
|
282
|
+
def visit_ClassDef(self, node):
|
|
283
|
+
"""
|
|
284
|
+
An override method, iterating over all ClassDef nodes and saving target ast nodes.
|
|
285
|
+
|
|
286
|
+
Args:
|
|
287
|
+
node (ast.AST): An instance of ast node which is visited currently.
|
|
288
|
+
"""
|
|
289
|
+
|
|
290
|
+
if node.name == self._target:
|
|
291
|
+
self._results.append(node)
|
|
292
|
+
|
|
293
|
+
def find_all(self, class_name: str) -> [ast.AST]:
|
|
294
|
+
"""
|
|
295
|
+
Find all matched ast node.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
class_name (str): Name of class to be found.
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
A list of instance of ast.ClassDef as matched result.
|
|
302
|
+
|
|
303
|
+
Raises:
|
|
304
|
+
TypeError: If input `class_name` is not str.
|
|
305
|
+
"""
|
|
306
|
+
if not isinstance(class_name, str):
|
|
307
|
+
raise TypeError("Input class_name should be a str")
|
|
308
|
+
self._target = class_name
|
|
309
|
+
self._results.clear()
|
|
310
|
+
self.visit(self._scope)
|
|
311
|
+
return self._results
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
class AstFunctionFinder(ast.NodeVisitor):
|
|
315
|
+
"""
|
|
316
|
+
Find all ast function nodes with a specific name in a specific scope.
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
node (ast.AST): An instance of ast node as search scope.
|
|
320
|
+
"""
|
|
321
|
+
|
|
322
|
+
def __init__(self, node: ast.AST):
|
|
323
|
+
self._scope: ast.AST = node
|
|
324
|
+
self._target: str = ""
|
|
325
|
+
self._results: [ast.ClassDef] = []
|
|
326
|
+
|
|
327
|
+
def visit_FunctionDef(self, node):
|
|
328
|
+
"""
|
|
329
|
+
An override method, iterating over all FunctionDef nodes and saving target ast nodes.
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
node (ast.AST): An instance of ast node which is visited currently.
|
|
333
|
+
"""
|
|
334
|
+
|
|
335
|
+
if node.name == self._target:
|
|
336
|
+
self._results.append(node)
|
|
337
|
+
|
|
338
|
+
def find_all(self, func_name: str) -> [ast.AST]:
|
|
339
|
+
"""
|
|
340
|
+
Find all matched ast node.
|
|
341
|
+
|
|
342
|
+
Args:
|
|
343
|
+
func_name (str): Name of function to be found.
|
|
344
|
+
|
|
345
|
+
Returns:
|
|
346
|
+
A list of instance of ast.FunctionDef as matched result.
|
|
347
|
+
|
|
348
|
+
Raises:
|
|
349
|
+
TypeError: If input `func_name` is not str.
|
|
350
|
+
"""
|
|
351
|
+
if not isinstance(func_name, str):
|
|
352
|
+
raise TypeError("Input func_name should be a str")
|
|
353
|
+
self._target = func_name
|
|
354
|
+
self._results.clear()
|
|
355
|
+
self.visit(self._scope)
|
|
356
|
+
return self._results
|