mindspore 2.1.0-cp38-cp38-manylinux1_x86_64.whl → 2.2.0-cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as potentially problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +26 -32
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +72 -95
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +173 -258
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +240 -145
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +13 -2
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +143 -59
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +11 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +0 -14
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +316 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +21 -28
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +310 -207
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +82 -41
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +13 -18
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +22 -17
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +78 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +4 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +167 -189
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -8
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +470 -251
- mindspore/ops/function/random_func.py +86 -56
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +235 -19
- mindspore/ops/operations/__init__.py +25 -17
- mindspore/ops/operations/_grad_ops.py +52 -7
- mindspore/ops/operations/_inner_ops.py +213 -12
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +64 -280
- mindspore/ops/operations/comm_ops.py +105 -57
- mindspore/ops/operations/custom_ops.py +10 -3
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/math_ops.py +185 -138
- mindspore/ops/operations/nn_ops.py +716 -492
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +14 -12
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +6 -10
- mindspore/parallel/shard.py +4 -4
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
- mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
- mindspore/profiler/parser/ascend_op_generator.py +5 -5
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
- mindspore/profiler/parser/base_timeline_generator.py +9 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +37 -21
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +2 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +139 -71
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +525 -577
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +2 -2
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +14 -7
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +83 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +185 -45
- mindspore/train/serialization.py +390 -150
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +14 -10
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/profiler/profiling.py
CHANGED
@@ -71,15 +71,21 @@ AICORE_METRICS_DICT = {
  71   71   class DeviceSupportParam(Enum):
  72   72   """The device target enum."""
  73   73   CPU = ['start', 'start_profile', 'output_path', 'timeline_limit', 'profile_framework', 'op_time']
  74        - GPU = [
  75        -
  76        -
  77        -
  78        -
       74   + GPU = [
       75   + 'start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'sync_enable', 'op_time',
       76   + 'profile_framework'
       77   + ]
       78   + ASCEND = [
       79   + 'start', 'start_profile', 'output_path', 'data_process', 'timeline_limit', 'profile_memory',
       80   + 'parallel_strategy', 'profile_communication', 'aicore_metrics', 'l2_cache', 'op_time', 'ascend_job_id',
       81   + 'profile_framework'
       82   + ]
  79   83
  80   84
  81        - ALWAYS_VALID_PARAM = [
  82        -
       85   + ALWAYS_VALID_PARAM = [
       86   + 'start', 'start_profile', 'output_path', 'data_process', 'parallel_strategy', 'l2_cache',
       87   + 'ascend_job_id', 'op_time', 'profile_framework'
       88   + ]
  83   89
  84   90
  85   91   def _environment_check():

@@ -161,6 +167,7 @@ def _calculate_dataset_execution_time(input_file, output_file):
 161  167   csv_writer.writerow(['Operation', 'Stage', 'Occurrences', 'Avg. time (us)', 'Custom Info'])
 162  168   for _, v in execution_time_map.items():
 163  169   csv_writer.writerow([v.event, v.stage, v.count, v.average_execution, v.custom_info])
      170   + os.chmod(output_file, modes)
 164  171   logger.info('Successfully calculate the execution time and write it to file: %s.', output_file)
 165  172
 166  173

@@ -188,8 +195,10 @@ def _extract_timeline_item(row, time_line, ts_map):
 188  195   # Put the instance event into timeline.
 189  196   elif start_end == '2':
 190  197   title = row['event'] + '::' + row['stage']
 191        - event = {
 192        -
      198   + event = {
      199   + 'name': title, 'cat': row['module_name'], 'ts': int(row['time_stamp(us)']), 'ph': 'i',
      200   + 'pid': row['pid'], 'tid': row['tid'], 'args': {'parent_pid': row['parent_pid']}
      201   + }
 193  202   time_line.append(event)
 194  203   else:
 195  204   logger.warning("Can not map the start time for item: %s.", row)

@@ -209,8 +218,10 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
 209  218   time_line = []
 210  219   # ts_map is used to store the start time of each event_stage_tid_pid
 211  220   ts_map = {}
 212        - memory_header = [
 213        -
      221   + memory_header = [
      222   + 'tid', 'pid', 'parent_pid', 'module_name', 'event', 'stage', 'level', 'start_end', 'custom_info',
      223   + 'memory_usage(kB)', 'time_stamp(us)'
      224   + ]
 214  225   memory_info = []
 215  226   with open(input_file, 'r') as f:
 216  227   for row in csv.DictReader(f):

@@ -226,12 +237,12 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
 226  237   logger.error("Error occur when analyse line: %s, Details is: %s", row, e)
 227  238   continue
 228  239   if memory_info:
 229        - with os.fdopen(os.open(output_memory_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
 230        - 'w') as csv_file:
      240   + with os.fdopen(os.open(output_memory_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as csv_file:
 231  241   csv_writer = csv.DictWriter(csv_file, fieldnames=memory_header)
 232  242   csv_writer.writeheader()
 233  243   for item in memory_info:
 234  244   csv_writer.writerow(item)
      245   + os.chmod(output_memory_file, stat.S_IREAD | stat.S_IWRITE)
 235  246   else:
 236  247   logger.warning("No memory_usage is record in file: %s", input_file)
 237  248

@@ -255,13 +266,23 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
 255  266
 256  267   if time_line:
 257  268   timeline_file = validate_and_normalize_path(output_timeline_file)
 258        - with os.fdopen(os.open(timeline_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
 259        - 'w') as json_file:
      269   + with os.fdopen(os.open(timeline_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as json_file:
 260  270   json.dump(time_line, json_file)
      271   + os.chmod(timeline_file, stat.S_IREAD | stat.S_IWRITE)
 261  272   else:
 262  273   logger.warning("No valid time_stamp is record in file: %s", input_file)
 263  274
 264  275
      276   + def _ascend_graph_msprof_generator(source_path, model_iteration_dict):
      277   + try:
      278   + msprof_exporter = AscendMsprofExporter(source_path)
      279   + msprof_exporter.export(model_iteration_dict)
      280   + except ProfilerException as err:
      281   + logger.warning(err.message)
      282   + finally:
      283   + pass
      284   +
      285   +
 265  286   def _ascend_graph_msprof_analyse(source_path):
 266  287   """
 267  288   Ascend graph model msprof data analyse.

@@ -287,7 +308,8 @@ class Profiler:
 287  308   This class to enable the profiling of MindSpore neural networks.
 288  309   MindSpore users can import the mindspore.Profiler, initialize the Profiler object to start profiling,
 289  310   and use Profiler.analyse() to stop profiling and analyse the results.
 290        - Users can visualize the results using the
      311   + Users can visualize the results using the `MindSpore Insight
      312   + <https://www.mindspore.cn/mindinsight/docs/en/r2.2/index.html>`_ tool.
 291  313   Now, Profiler supports AICORE operator, AICPU operator, HostCPU operator, memory,
 292  314   correspondence, cluster, etc data analysis.
 293  315

@@ -330,11 +352,16 @@ class Profiler:
 330  352   Default value: ``True`` .
 331  353   timeline_limit (int, optional): (Ascend/GPU) Set the maximum storage size of the timeline file (unit M).
 332  354   When using this parameter, `op_time` must be set to True. Default value: ``500`` .
 333        - profile_framework (str, optional): (Ascend/GPU)
 334        - ["all", "time", "memory", None]
      355   + profile_framework (str, optional): (Ascend/GPU) The host information to collect, it must be one of
      356   + ["all", "time", "memory", None], When is not set to None, a subdirectory host_info will be generated in the
 335  357   specified profiler directory, which stores the collected memory and time files on the Host side.
 336  358   Default: "all".
 337  359
      360   + - "all": Record both host timestamp and host memory usage.
      361   + - "time": Only record host timestamp.
      362   + - "memory": Only record host memory usage.
      363   + - None: Not record host information.
      364   +
 338  365   Raises:
 339  366   RuntimeError: When the version of CANN does not match the version of MindSpore,
 340  367   MindSpore cannot parse the generated ascend_job_id directory structure.

@@ -407,7 +434,6 @@ class Profiler:
 407  434   self._rank_size = 1
 408  435   self._rank_id = 0
 409  436   self._ascend_profiler = None
 410        - self._ascend_msprof_exporter = None
 411  437   self._timeline_size_limit_byte = 500 * 1024 * 1024 # 500MB
 412  438   self._parallel_strategy = True
 413  439   _environment_check()

@@ -424,6 +450,7 @@ class Profiler:
 424  450   self._sync_enable = True
 425  451   self._stop_time = 0
 426  452   self._dynamic_status = False
      453   + self._model_iteration_dict = None
 427  454   self._profile_framework = "all"
 428  455   self._msprof_enable = os.getenv("PROFILER_SAMPLECONFIG")
 429  456   if self._msprof_enable:

@@ -476,6 +503,25 @@ class Profiler:
 476  503
 477  504   return job_start_time
 478  505
      506   + @staticmethod
      507   + def _parse_info_json(info_file):
      508   + """
      509   + Parse info log file, get the rank id and device id of the job.
      510   + Args:
      511   + input_file (str): The file path of the parse info log file.
      512   +
      513   + Returns:
      514   + rank id, device id
      515   + """
      516   + with open(info_file, "r") as f:
      517   + info_dict = json.load(f)
      518   +
      519   + rank_id = info_dict.get("rank_id", 0)
      520   + dev_info = info_dict.get("DeviceInfo", [])
      521   + dev_id = dev_info[0].get("id", -1)
      522   +
      523   + return str(rank_id), str(dev_id)
      524   +
 479  525   def op_analyse(self, op_name, device_id=None):
 480  526   """
 481  527   Profiler users can use this interface to obtain operator performance data.

@@ -487,8 +533,8 @@ class Profiler:
 487  533   parse. If this interface is used for offline data parsing, Default: ``0`` .
 488  534
 489  535   Raises:
 490        - TypeError: If the op_name parameter type is incorrect.
 491        - TypeError: If the device_id parameter type is incorrect.
      536   + TypeError: If the `op_name` parameter type is incorrect.
      537   + TypeError: If the `device_id` parameter type is incorrect.
 492  538   RuntimeError: If MindSpore runs on Ascend, this interface cannot be used.
 493  539
 494  540   Supported Platforms:

@@ -501,12 +547,12 @@ class Profiler:
 501  547   >>> # Profiler init.
 502  548   >>> profiler = Profiler()
 503  549   >>> # Train Model or eval Model, taking LeNet5 as an example.
 504        - >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.
      550   + >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 505  551   >>> net = LeNet5()
 506  552   >>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
 507  553   >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
 508  554   >>> # Create the dataset taking MNIST as an example.
 509        - >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.
      555   + >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/mnist.py
 510  556   >>> dataloader = create_dataset()
 511  557   >>> model = Model(net, loss, optimizer)
 512  558   >>> model.train(5, dataloader, dataset_sink_mode=False)

@@ -550,7 +596,22 @@ class Profiler:
 550  596   Offline mode isused in abnormal exit scenario. This parameter should be set to ``None``
 551  597   for online mode. Default: ``None``.
 552  598   """
      599   + self._analyse(offline_path=offline_path)
      600   +
      601   + def _analyse(self, offline_path=None, model_iteration_dict=None):
      602   + """
      603   + Collect and analyze training performance data, support calls during and after training. The example shows above.
      604   +
      605   + Args:
      606   + offline_path (Union[str, None], optional): The data path which need to be analysed with offline mode.
      607   + Offline mode isused in abnormal exit scenario. This parameter should be set to ``None``
      608   + for online mode. Default: ``None``.
      609   + model_iteration_dict: Dictionary with model id as the key and iteration id as the value, Default: ``None``.
      610   + """
      611   + self._model_iteration_dict = model_iteration_dict
 553  612   if offline_path:
      613   + if self._is_offline_parser():
      614   + self._ascend_graph_analyse()
 554  615   _offline_parse(offline_path)
 555  616   return
 556  617   if self._msprof_enable:

@@ -602,7 +663,7 @@ class Profiler:
 602  663   Raises:
 603  664   RuntimeError: If the profiler has already started.
 604  665   RuntimeError: If MD profiling has stopped, repeated start action is not supported.
 605        - RuntimeError: If the start_profile parameter is not set or is set to True
      666   + RuntimeError: If the `start_profile` parameter is not set or is set to ``True``.
 606  667
 607  668   Examples:
 608  669   >>> from mindspore.train import Callback

@@ -749,7 +810,6 @@ class Profiler:
 749  810
 750  811   if self._device_target == DeviceTarget.ASCEND.value:
 751  812   self._ascend_profiler = c_expression.Profiler.get_instance("Ascend")
 752        - self._ascend_msprof_exporter = AscendMsprofExporter(self._output_path)
 753  813   self._get_devid_rankid_and_devtarget()
 754  814
 755  815   def _init_profiler_info(self):

@@ -827,7 +887,6 @@ class Profiler:
 827  887   # use context interface to open profiling, for the new mindspore version(after 2020.5.21)
 828  888   self._ascend_profiler = c_expression.Profiler.get_instance("Ascend")
 829  889   self._ascend_profiler.init(self._output_path, int(self._dev_id), self._ascend_profiling_options)
 830        - self._ascend_msprof_exporter = AscendMsprofExporter(self._output_path)
 831  890   base_profiling_container_path = os.path.join(self._output_path, "container")
 832  891   container_path = os.path.join(base_profiling_container_path, self._dev_id)
 833  892   data_path = os.path.join(container_path, "data")

@@ -965,8 +1024,6 @@ class Profiler:
 965 1024   else:
 966 1025   logger.info("No need to stop profiler because profiler has been stopped.")
 967 1026   # export op data before analyse
 968        - if self._op_time:
 969        - self._ascend_msprof_exporter.export(self._start_time, support_step_trace=False)
 970 1027   self._ascend_graph_analyse()
 971 1028
 972 1029   def _minddata_analyse(self, source_path):

@@ -1040,8 +1097,11 @@ class Profiler:
1040 1097   aicpu_intermediate_detail_path = validate_and_normalize_path(aicpu_intermediate_detail_path)
1041 1098   framework_raw_path = validate_and_normalize_path(framework_raw_path)
1042 1099
1043        -
1044        -
     1100   + if context.get_context("mode") == context.GRAPH_MODE:
     1101   + output_timeline_data_path = os.path.join(self._output_path, f'output_timeline_data_{dev_id}.txt')
     1102   + output_timeline_data_path = validate_and_normalize_path(output_timeline_data_path)
     1103   + else:
     1104   + output_timeline_data_path = None
1045 1105
1046 1106   op_analyser = AscendOPGenerator(op_summary, op_statistic, dynamic_status)
1047 1107   op_analyser.parse()

@@ -1070,7 +1130,7 @@ class Profiler:
1070 1130   finally:
1071 1131   pass
1072 1132
1073        - def _ascend_timeline_analyse(self,
     1133   + def _ascend_timeline_analyse(self, op_summary, steptrace):
1074 1134   """Analyse timeline info."""
1075 1135   try:
1076 1136   logger.info("Profiling: analyzing the timeline data")

@@ -1142,6 +1202,7 @@ class Profiler:
1142 1202   if self._profile_communication and context.get_context("mode") == context.PYNATIVE_MODE:
1143 1203   logger.warning("[Profiler]The parameter profile_communication is not supported on Ascend "
1144 1204   "PyNative mode currently.")
     1205   + return
1145 1206   try:
1146 1207   logger.info("Profiling: analyzing the hccl profiler info.")
1147 1208   dev_id = self._rank_id if self._device_target == DeviceTarget.ASCEND.value else self._dev_id

@@ -1191,9 +1252,10 @@ class Profiler:
1191 1252   source_path = os.path.join(self._output_path, job_id)
1192 1253   self._minddata_analyse(source_path)
1193 1254   if self._op_time:
     1255   + _ascend_graph_msprof_generator(source_path, self._model_iteration_dict)
1194 1256   op_summary, op_statistic, steptrace = _ascend_graph_msprof_analyse(source_path)
1195 1257   self._ascend_op_analyse(op_summary, op_statistic, self._dynamic_status)
1196        - self._ascend_timeline_analyse(
     1258   + self._ascend_timeline_analyse(op_summary, steptrace)
1197 1259   graph_ids = np.unique(op_summary['Model ID']).tolist()
1198 1260   points = self._ascend_fpbp_analyse(op_summary, steptrace)
1199 1261   if len(graph_ids) == 1:

@@ -1326,29 +1388,37 @@ class Profiler:
1326 1388   point_info_file_path = validate_and_normalize_path(point_info_file_path)
1327 1389
1328 1390   if self._device_target and self._device_target == DeviceTarget.GPU.value:
1329        -
1330        -
1331        -
1332        -
1333        -
1334        -
1335        -
1336        -
1337        -
1338        -
1339        -
1340        -
1341        -
1342        -
1343        -
1344        -
1345        -
1346        -
1347        -
1348        -
1349        -
1350        -
1351        -
     1391   + if context.get_context("mode") != context.PYNATIVE_MODE:
     1392   + input_file_path = os.path.join(self._output_path, f'step_trace_profiling_{self._dev_id}.txt')
     1393   + input_file_path = validate_and_normalize_path(input_file_path)
     1394   + parser = GpuStepTraceParser(input_dir=input_file_path,
     1395   + output_file_path=step_trace_intermediate_file_path,
     1396   + is_training_mode=is_training_mode_flag,
     1397   + is_gpu_kernel_async_launch=is_gpu_kernel_async_launch_flag)
     1398   + parser.parse_and_save()
     1399   + point_info = parser.record_point_info(point_info_file_path)
     1400   + # print parser result
     1401   + parser.show()
     1402   + logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)
     1403   + logger.info("The point info is: %s", point_info)
     1404   +
     1405   + return point_info, is_training_mode_flag
     1406   + return {}, is_training_mode_flag
     1407   +
     1408   + # whether keep the first step
     1409   + skip_first_step_flag = framework_parser.check_op_name(INIT_OP_NAME)
     1410   + # recognize inference or training mode
     1411   + is_training_mode_flag = framework_parser.check_op_name("Gradients")
     1412   + # parser the step trace files and save the result to disk
     1413   + source_path = validate_and_normalize_path(source_path)
     1414   + parser = AscendStepTraceParser(input_dir=source_path,
     1415   + output_file_path=step_trace_intermediate_file_path,
     1416   + skip_first_step=skip_first_step_flag,
     1417   + is_training_mode=is_training_mode_flag)
     1418   + parser.set_task_id_op_name_dict(framework_parser.to_task_id_full_op_name_dict())
     1419   + parser.parse_and_save()
     1420   + point_info = parser.record_point_info(point_info_file_path)
     1421   +
1352 1422   # print parser result
1353 1423   parser.show()
1354 1424   logger.info("Finish saving the intermediate result: %s", step_trace_intermediate_file_path)

@@ -1393,11 +1463,10 @@ class Profiler:
1393 1463   return job_id
1394 1464
1395 1465   job_id = ""
1396        - job_dirs = filter(lambda item: item.startswith('JOB') or item.startswith('PROF') and
1397        -
1398        -
1399        -
1400        - reverse=True)
     1466   + job_dirs = filter(lambda item: item.startswith('JOB') or item.startswith('PROF') and os.path.isdir(
     1467   + os.path.join(self._output_path, item)), os.listdir(self._output_path))
     1468   + sorted_job_dirs = sorted(
     1469   + job_dirs, key=lambda x: os.path.getmtime(os.path.join(self._output_path, x)), reverse=True)
1401 1470
1402 1471   for dir_name in sorted_job_dirs:
1403 1472   if dir_name.startswith('PROF'):

@@ -1414,22 +1483,21 @@ class Profiler:
1414 1483   "profiler will ignore this job dir.", job_dir)
1415 1484   continue
1416 1485
1417        -
     1486   + info_file_path = get_file_path(job_dir, "info.json")
     1487   + if info_file_path is None:
     1488   + logger.warning("Find profiling job path %s, but info.json not exist, "
     1489   + "profiler will ignore this job dir.", job_dir)
     1490   + continue
     1491   +
     1492   + _, training_device_id = self._parse_info_json(info_file_path)
     1493   + job_start_time = self._parse_start_log(start_file_path)
     1494   +
1418 1495   if self._dev_id != training_device_id:
1419 1496   logger.debug("Find profiling find job path %s, but not current training device id. "
1420 1497   "Current training device id %s, but job path device id: %s, "
1421 1498   "profiler will ignore this job dir.", job_dir, self._dev_id, training_device_id)
1422 1499   continue
1423 1500
1424        - if not os.listdir(os.path.join(job_dir, 'data')):
1425        - continue
1426        -
1427        - job_start_time = self._parse_start_log(start_file_path)
1428        - if not job_start_time:
1429        - logger.warning("Find profiling job path %s, but fail to get job start info, "
1430        - "profiler will ignore this job dir.", job_start_time)
1431        - continue
1432        -
1433 1501   if int(job_start_time) < self._start_time:
1434 1502   logger.warning("Find profiling job path %s, but start_time(%d) is earlier than this training "
1435 1503   "start_time(%d), profiler will ignore this job dir.",

@@ -1586,7 +1654,7 @@ class Profiler:
1586 1654   self._profile_framework = kwargs.pop("profile_framework", "all")
1587 1655   if self._profile_framework not in ["memory", "time", "all", None]:
1588 1656   logger.warning(f"For '{self.__class__.__name__}', the parameter profile_framework must be one of ['memory',"
1589        - f" 'time', 'all', None]
     1657   + f" 'time', 'all', None], but got {self._profile_framework}, it will be set to 'all'.")
1590 1658   self._profile_framework = "all"
1591 1659
1592 1660   def _host_info_analyse(self):
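The hunks above document the new `profile_framework` host-collection option and route `Profiler.analyse()` through an internal `_analyse()` helper. Below is a minimal sketch of how the documented option might be exercised; the output path, device target, and the placeholder training step are assumptions, not part of this diff.

```python
# Hedged sketch of the profile_framework option documented in the 2.2.0
# Profiler docstring above. Output path and device target are assumptions.
import mindspore as ms
from mindspore import Profiler

ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

# "all" (default) records host timestamps and host memory usage,
# "time" records only timestamps, "memory" only memory usage,
# None disables host-side collection entirely.
profiler = Profiler(output_path="./profiler_data", profile_framework="time")

# ... build the network and run training or evaluation here ...

# analyse() stops profiling; per the diff it now delegates to _analyse(),
# and host data is written to a host_info subdirectory under output_path.
profiler.analyse()
```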
mindspore/rewrite/api/node.py
CHANGED
@@ -14,12 +14,13 @@
  14   14   # ============================================================================
  15   15   """Rewrite module api: Node."""
  16   16
  17        - from typing import Union, Optional
       17   + from typing import Union, Optional, List, Dict
       18   + from types import FunctionType
  18   19
  19   20   from mindspore.nn import Cell
  20   21   from mindspore.ops.primitive import Primitive
  21   22   from mindspore import _checkparam as Validator
  22        - from ..node import Node as NodeImpl
       23   + from ..node.node import Node as NodeImpl
  23   24   from ..symbol_tree import SymbolTree as SymbolTreeImpl
  24   25   from .node_type import NodeType
  25   26   from .scoped_value import ScopedValue

@@ -50,8 +51,8 @@ class Node:
  50   51   return self._node == other._node
  51   52
  52   53   @staticmethod
  53        - def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
  54        - kwargs:
       54   + def create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None,
       55   + kwargs: Dict[str, ScopedValue] = None, name: str = "", is_sub_net: bool = False) -> 'Node':
  55   56   """
  56   57   Create a node. Only support create from a `Cell` now.
  57   58

@@ -63,14 +64,15 @@ class Node:
  63   64
  64   65   Args:
  65   66   cell (Cell): Cell-operator of this forward-layer.
  66        - targets (
  67        -
       67   + targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
       68   + source code.
       69   + args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
  68   70   source code. Default: ``None`` , which indicates the `cell` has no args inputs.
  69        - kwargs (
       71   + kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
  70   72   Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
  71   73   code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
  72   74   name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
  73        - generate name from `
       75   + generate name from `cell` when name is None. Rewrite will check and ensure the uniqueness of `name`
  74   76   while node being inserted. Default: ``""`` .
  75   77   is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
  76   78   the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node. Default: ``False`` .

@@ -89,7 +91,7 @@ class Node:
  89   91   >>> from mindspore.rewrite import SymbolTree, ScopedValue
  90   92   >>> import mindspore.nn as nn
  91   93   >>> # Define the network structure of LeNet5. Refer to
  92        - >>> # https://gitee.com/mindspore/docs/blob/r2.
       94   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
  93   95   >>> net = LeNet5()
  94   96   >>> stree = SymbolTree.create(net)
  95   97   >>> node = stree.get_node("conv1")

@@ -108,8 +110,66 @@ class Node:
 108  110   Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
 109  111   if kwargs is not None:
 110  112   Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
 111        - return Node(NodeImpl.create_call_op(cell, None, targets,
 112        -
      113   + return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))
      114   +
      115   + @staticmethod
      116   + def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
      117   + args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
      118   + """
      119   + Create a node that corresponds to a function call. The `function` object is saved into network, and used via
      120   + getting object from `self.` .
      121   +
      122   + Args:
      123   + function (FunctionType): The function to be called.
      124   + targets (List[Union[ScopedValue, str]]): indicates output names. Used as targets of an assign statement in
      125   + source code.
      126   + args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
      127   + source code. Default: ``None`` , which indicates the `function` has no args inputs.
      128   + kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
      129   + Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
      130   + code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
      131   +
      132   + Returns:
      133   + An instance of `Node`.
      134   +
      135   + Raises:
      136   + TypeError: If `function` is not a `FunctionType`.
      137   + TypeError: If `targets` is not `list`.
      138   + TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
      139   + TypeError: If arg in `args` is not a `ScopedValue`.
      140   + TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
      141   +
      142   + Examples:
      143   + >>> from mindspore.rewrite import SymbolTree, ScopedValue
      144   + >>> import mindspore.nn as nn
      145   + >>> import mindspore.ops as ops
      146   + >>> # Define the network structure of LeNet5. Refer to
      147   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
      148   + >>> net = LeNet5()
      149   + >>> stree = SymbolTree.create(net)
      150   + >>> node = stree.get_node("conv1")
      151   + >>> position = stree.after(node)
      152   + >>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
      153   + ...                                      args=[ScopedValue.create_naming_value('x')])
      154   + >>> stree.insert(position, new_node)
      155   + >>> print(new_node.get_node_type())
      156   + NodeType.CallFunction
      157   + """
      158   + Validator.check_value_type("function", function, [FunctionType, type], "create_call_function")
      159   + Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
      160   + if args is not None:
      161   + Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
      162   + if kwargs is not None:
      163   + Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
      164   + return Node(NodeImpl._create_call_function(function, targets, args, kwargs))
      165   +
      166   + @staticmethod
      167   + def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
      168   + # pylint: disable=missing-function-docstring
      169   + Validator.check_value_type("param_name", param_name, [str], "Node")
      170   + if default is not None:
      171   + Validator.check_value_type("default", default, [ScopedValue], "Node")
      172   + return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))
 113  173
 114  174   def get_handler(self) -> NodeImpl:
 115  175   return self._node

@@ -124,7 +184,7 @@ class Node:
 124  184   Examples:
 125  185   >>> from mindspore.rewrite import SymbolTree
 126  186   >>> # Define the network structure of LeNet5. Refer to
 127        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      187   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 128  188   >>> net = LeNet5()
 129  189   >>> stree = SymbolTree.create(net)
 130  190   >>> node = stree.get_node("conv2")

@@ -144,7 +204,7 @@ class Node:
 144  204   Examples:
 145  205   >>> from mindspore.rewrite import SymbolTree
 146  206   >>> # Define the network structure of LeNet5. Refer to
 147        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      207   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 148  208   >>> net = LeNet5()
 149  209   >>> stree = SymbolTree.create(net)
 150  210   >>> node = stree.get_node("conv1")

@@ -177,7 +237,7 @@ class Node:
 177  237   Examples:
 178  238   >>> from mindspore.rewrite import SymbolTree
 179  239   >>> # Define the network structure of LeNet5. Refer to
 180        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      240   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 181  241   >>> net = LeNet5()
 182  242   >>> stree = SymbolTree.create(net)
 183  243   >>> node = stree.get_node("relu_3")

@@ -216,7 +276,7 @@ class Node:
 216  276   Examples:
 217  277   >>> from mindspore.rewrite import SymbolTree
 218  278   >>> # Define the network structure of LeNet5. Refer to
 219        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      279   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 220  280   >>> net = LeNet5()
 221  281   >>> stree = SymbolTree.create(net)
 222  282   >>> src_node = stree.get_node("fc1")

@@ -256,7 +316,7 @@ class Node:
 256  316   Examples:
 257  317   >>> from mindspore.rewrite import SymbolTree
 258  318   >>> # Define the network structure of LeNet5. Refer to
 259        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      319   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 260  320   >>> net = LeNet5()
 261  321   >>> stree = SymbolTree.create(net)
 262  322   >>> node = stree.get_node("conv1")

@@ -276,7 +336,7 @@ class Node:
 276  336   Examples:
 277  337   >>> from mindspore.rewrite import SymbolTree
 278  338   >>> # Define the network structure of LeNet5. Refer to
 279        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      339   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 280  340   >>> net = LeNet5()
 281  341   >>> stree = SymbolTree.create(net)
 282  342   >>> node = stree.get_node("conv1")

@@ -303,7 +363,7 @@ class Node:
 303  363   Examples:
 304  364   >>> from mindspore.rewrite import SymbolTree
 305  365   >>> # Define the network structure of LeNet5. Refer to
 306        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      366   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 307  367   >>> net = LeNet5()
 308  368   >>> stree = SymbolTree.create(net)
 309  369   >>> node = stree.get_node("conv1")

@@ -326,7 +386,7 @@ class Node:
 326  386   Examples:
 327  387   >>> from mindspore.rewrite import SymbolTree
 328  388   >>> # Define the network structure of LeNet5. Refer to
 329        - >>> # https://gitee.com/mindspore/docs/blob/r2.
      389   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
 330  390   >>> net = LeNet5()
 331  391   >>> stree = SymbolTree.create(net)
 332  392   >>> node = stree.get_node("conv1")

@@ -335,6 +395,29 @@ class Node:
 335  395   """
 336  396   return self._node.get_args()
 337  397
      398   + def get_symbol_tree(self) -> 'SymbolTree':
      399   + """
      400   + Get the symbol tree which current node belongs to.
      401   +
      402   + Returns:
      403   + SymbolTree, None if current node does not belong to any SymbolTree.
      404   +
      405   + Examples:
      406   + >>> from mindspore.rewrite import SymbolTree
      407   + >>> # Define the network structure of LeNet5. Refer to
      408   + >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
      409   + >>> net = LeNet5()
      410   + >>> stree = SymbolTree.create(net)
      411   + >>> node = stree.get_node("conv1")
      412   + >>> print(type(node.get_symbol_tree()))
      413   + <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
      414   + """
      415   + from .symbol_tree import SymbolTree
      416   + stree_impl = self._node.get_belong_symbol_tree()
      417   + if not stree_impl:
      418   + return None
      419   + return SymbolTree(stree_impl)
      420   +
 338  421   def get_kwargs(self) -> {str: ScopedValue}:
 339  422   return self._node.get_kwargs()
 340  423
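Taken together, the node.py hunks add `Node.create_call_function` and `Node.get_symbol_tree` to the rewrite API. The sketch below strings the two together, following the docstring examples shown in the diff; `LeNet5` is assumed to be a user-defined `mindspore.nn.Cell` and is not part of this package.

```python
# Hedged sketch combining the new rewrite APIs from this diff
# (Node.create_call_function and Node.get_symbol_tree).
# LeNet5 is assumed to be a user-defined mindspore.nn.Cell.
import mindspore.ops as ops
from mindspore.rewrite import SymbolTree, ScopedValue, Node

net = LeNet5()                      # assumption: user-defined network
stree = SymbolTree.create(net)
node = stree.get_node("conv1")

# Insert an ops.abs call right after conv1, reusing conv1's output name 'x'.
new_node = Node.create_call_function(function=ops.abs, targets=['x'],
                                     args=[ScopedValue.create_naming_value('x')])
stree.insert(stree.after(node), new_node)

# Each node can now report the SymbolTree it belongs to.
print(type(new_node.get_symbol_tree()))  # mindspore.rewrite.api.symbol_tree.SymbolTree
```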