mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/mindspore/service.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
# Copyright 2024 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
2
3
|
#
|
|
3
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
5
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,39 +12,42 @@
|
|
|
11
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
13
|
# See the License for the specific language governing permissions and
|
|
13
14
|
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
15
|
|
|
16
|
-
import os
|
|
17
16
|
import copy
|
|
18
17
|
import functools
|
|
18
|
+
import os
|
|
19
19
|
from collections import defaultdict
|
|
20
20
|
|
|
21
21
|
import mindspore as ms
|
|
22
|
-
from mindspore.common.tensor import Tensor
|
|
23
|
-
from mindspore import ops
|
|
24
22
|
from mindspore import nn
|
|
23
|
+
from mindspore.common.api import _no_grad
|
|
24
|
+
from mindspore.ops.primitive import Primitive
|
|
25
25
|
try:
|
|
26
26
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
27
|
-
pijit_label = True
|
|
28
27
|
except ImportError:
|
|
29
28
|
pijit_label = False
|
|
29
|
+
else:
|
|
30
|
+
pijit_label = True
|
|
30
31
|
|
|
31
|
-
|
|
32
|
+
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
|
|
33
|
+
from msprobe.core.common.file_utils import create_directory
|
|
34
|
+
from msprobe.core.common.utils import Const, print_tools_ends_info
|
|
32
35
|
from msprobe.core.data_dump.data_collector import build_data_collector
|
|
36
|
+
from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
|
|
37
|
+
ModuleBackwardInputs)
|
|
33
38
|
from msprobe.core.data_dump.scope import BaseScope
|
|
34
|
-
from msprobe.mindspore.
|
|
35
|
-
from msprobe.core.common.file_utils import create_directory
|
|
39
|
+
from msprobe.mindspore.cell_processor import CellProcessor
|
|
36
40
|
from msprobe.mindspore.common.log import logger
|
|
37
|
-
from msprobe.
|
|
38
|
-
|
|
41
|
+
from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs,
|
|
42
|
+
is_mindtorch, register_backward_hook_functions)
|
|
39
43
|
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
|
|
40
44
|
from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
|
|
41
|
-
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
|
|
42
|
-
ModuleBackwardInputs, ModuleBackwardOutputs
|
|
43
|
-
from msprobe.core.common.exceptions import MsprobeException
|
|
44
|
-
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
45
|
-
from msprobe.mindspore.cell_processor import CellProcessor
|
|
46
45
|
from msprobe.mindspore.dump.jit_dump import JitDump
|
|
46
|
+
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
|
|
47
|
+
from msprobe.mindspore.dump.kernel_dump.kernel_config import create_kernel_config_json
|
|
48
|
+
|
|
49
|
+
if is_mindtorch():
|
|
50
|
+
import torch
|
|
47
51
|
|
|
48
52
|
|
|
49
53
|
class Service:
|
|
@@ -55,75 +59,196 @@ class Service:
|
|
|
55
59
|
self.cell_processor = CellProcessor(self.data_collector.scope)
|
|
56
60
|
self.primitive_hook_service = PrimitiveHookService(self)
|
|
57
61
|
self.switch = False
|
|
62
|
+
self.inner_switch = False
|
|
58
63
|
self.primitive_switch = False
|
|
59
64
|
self.current_iter = 0
|
|
60
65
|
self.first_start = True
|
|
61
66
|
self.current_rank = None
|
|
62
67
|
self.dump_iter_dir = None
|
|
63
68
|
self.start_call = False
|
|
64
|
-
self.check_level_valid()
|
|
65
69
|
self.should_stop_service = False
|
|
70
|
+
self.params_grad_info = {}
|
|
71
|
+
# 提前注册,确保注册尽可能多的API hook
|
|
72
|
+
self.register_api_hook()
|
|
66
73
|
|
|
67
74
|
@staticmethod
|
|
68
|
-
def check_model_valid(
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
)
|
|
75
|
+
def check_model_valid(models):
|
|
76
|
+
target_module_type = (torch.nn.Module, "torch.nn.Module") if is_mindtorch() else (nn.Cell, "mindspore.nn.Cell")
|
|
77
|
+
if models is None or isinstance(models, target_module_type[0]):
|
|
78
|
+
return models
|
|
79
|
+
error_model = None
|
|
80
|
+
if isinstance(models, (list, tuple)):
|
|
81
|
+
for model in models:
|
|
82
|
+
if not isinstance(model, target_module_type[0]):
|
|
83
|
+
error_model = model
|
|
84
|
+
break
|
|
85
|
+
else:
|
|
86
|
+
error_model = models
|
|
74
87
|
|
|
75
|
-
|
|
76
|
-
|
|
88
|
+
if error_model is not None:
|
|
89
|
+
error_info = (f"The 'model' parameter must be a {target_module_type[1]} or list[{target_module_type[1]}] "
|
|
90
|
+
f"type, currently there is a {type(error_model)} type.")
|
|
77
91
|
raise MsprobeException(
|
|
78
|
-
MsprobeException.INVALID_PARAM_ERROR,
|
|
79
|
-
|
|
92
|
+
MsprobeException.INVALID_PARAM_ERROR, error_info)
|
|
93
|
+
return models
|
|
94
|
+
|
|
95
|
+
@staticmethod
|
|
96
|
+
def prepare_module_input_output(target_type, cell, input_data, output):
|
|
97
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
98
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
|
|
99
|
+
else:
|
|
100
|
+
module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output)
|
|
101
|
+
return module_input_output
|
|
80
102
|
|
|
81
103
|
def build_hook(self, target_type, name):
|
|
82
|
-
def
|
|
83
|
-
if not self.
|
|
104
|
+
def pre_hook(api_or_cell_name, cell, input_data):
|
|
105
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
106
|
+
clean_input_kwargs(cell)
|
|
84
107
|
return None
|
|
85
108
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
109
|
+
with _no_grad():
|
|
110
|
+
self.inner_switch = True
|
|
111
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
112
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
113
|
+
else:
|
|
114
|
+
cell.forward_data_collected = True
|
|
115
|
+
HOOKCell.add_cell_count(name)
|
|
116
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None)
|
|
117
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
118
|
+
self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
119
|
+
self.inner_switch = False
|
|
120
|
+
return input_data
|
|
121
|
+
|
|
122
|
+
def grad_hook(cell, ori_name, param_name):
|
|
123
|
+
def hook_fn(grad):
|
|
124
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
125
|
+
return None
|
|
126
|
+
self.inner_switch = True
|
|
127
|
+
self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
|
|
128
|
+
self.inner_switch = False
|
|
129
|
+
return None
|
|
92
130
|
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
131
|
+
return hook_fn
|
|
132
|
+
|
|
133
|
+
def register_param_hook(ori_name, cell, params_dict):
|
|
134
|
+
'''
|
|
135
|
+
注册参数hook
|
|
136
|
+
'''
|
|
137
|
+
# data_mode为forward时,不注册参数hook
|
|
138
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
139
|
+
for param_name, param in params_dict.items():
|
|
140
|
+
if param.requires_grad:
|
|
141
|
+
param.register_hook(grad_hook(cell, ori_name, param_name))
|
|
142
|
+
|
|
143
|
+
def init_params_grad_info(cell, params_dict):
|
|
144
|
+
'''
|
|
145
|
+
初始化参数梯度信息, 在前向hook结束后, 将参数梯度信息写入cache_data中用于占位
|
|
146
|
+
'''
|
|
147
|
+
if not params_dict:
|
|
148
|
+
return
|
|
149
|
+
if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
|
|
150
|
+
grad_name = cell.params_grad_name if hasattr(cell, 'params_grad_name') else None
|
|
151
|
+
# 判断是否已经在cache_data中进行了占位, 若没有则先写入cache_data中
|
|
152
|
+
if not self.params_grad_info.get(grad_name):
|
|
153
|
+
data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
|
|
154
|
+
# 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位
|
|
155
|
+
if data_info.get(grad_name):
|
|
156
|
+
# 将grad_name的data_info先写入cache_data中, 梯度计算后再更新
|
|
157
|
+
self.data_collector.handle_data(grad_name, data_info,
|
|
158
|
+
flush=self.data_collector.data_processor.is_terminated)
|
|
159
|
+
# 记录当前模块的参数梯度信息已占位
|
|
160
|
+
self.params_grad_info[grad_name] = True
|
|
161
|
+
|
|
162
|
+
def forward_hook(api_or_cell_name, cell, input_data, output):
|
|
163
|
+
if not self.should_execute_hook(target_type, cell, True):
|
|
164
|
+
clean_input_kwargs(cell)
|
|
165
|
+
return None
|
|
166
|
+
with _no_grad():
|
|
167
|
+
self.inner_switch = True
|
|
168
|
+
module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output)
|
|
169
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
170
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
171
|
+
params_dict = {key.split(Const.SEP)[-1]: value for key, value in cell.parameters_dict(
|
|
172
|
+
recurse=False).items()}
|
|
173
|
+
setattr(module_input_output, Const.PARAMS, params_dict)
|
|
174
|
+
# 判断是否需要注册参数hook
|
|
175
|
+
if not hasattr(cell, 'params_grad_name') and params_dict:
|
|
176
|
+
ori_name = api_or_cell_name.rsplit(Const.SEP, 2)[0]
|
|
177
|
+
grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
|
|
178
|
+
# 首次执行前向hook时,添加params_grad_name属性,并注册参数hook
|
|
179
|
+
setattr(cell, 'params_grad_name', grad_name)
|
|
180
|
+
register_param_hook(ori_name, cell, params_dict)
|
|
181
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
182
|
+
self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
183
|
+
init_params_grad_info(cell, params_dict)
|
|
184
|
+
else:
|
|
185
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
186
|
+
self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
187
|
+
|
|
188
|
+
if self.data_collector.if_return_forward_new_output():
|
|
189
|
+
forward_new_output = self.data_collector.get_forward_new_output()
|
|
190
|
+
self.inner_switch = False
|
|
191
|
+
return forward_new_output
|
|
192
|
+
clean_input_kwargs(cell)
|
|
193
|
+
self.inner_switch = False
|
|
194
|
+
return output
|
|
100
195
|
|
|
101
196
|
def backward_hook(api_or_cell_name, cell, grad_input, grad_output):
|
|
102
|
-
if not self.
|
|
197
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
103
198
|
return
|
|
199
|
+
self.inner_switch = True
|
|
104
200
|
|
|
201
|
+
need_exchange = True
|
|
105
202
|
if target_type == BaseScope.Module_Type_Module:
|
|
106
|
-
|
|
203
|
+
if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
|
|
204
|
+
need_exchange = False
|
|
205
|
+
api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
|
|
206
|
+
|
|
107
207
|
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
108
208
|
if self.data_collector:
|
|
109
209
|
# 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
|
|
110
|
-
|
|
210
|
+
if need_exchange:
|
|
211
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
|
|
212
|
+
else:
|
|
213
|
+
module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
|
|
111
214
|
self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
|
|
215
|
+
self.inner_switch = False
|
|
216
|
+
|
|
217
|
+
def pre_backward_hook(api_or_cell_name, cell, grad_input):
|
|
218
|
+
if not self.should_execute_hook(target_type, cell, False):
|
|
219
|
+
return
|
|
220
|
+
self.inner_switch = True
|
|
221
|
+
module_input = ModuleBackwardInputs(grad_input=grad_input)
|
|
222
|
+
self.data_collector.update_api_or_module_name(api_or_cell_name)
|
|
223
|
+
self.data_collector.backward_input_data_collect(api_or_cell_name, cell, pid, module_input)
|
|
224
|
+
|
|
225
|
+
self.inner_switch = False
|
|
112
226
|
|
|
113
227
|
pid = os.getpid()
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
228
|
+
if target_type == BaseScope.Module_Type_Module:
|
|
229
|
+
full_forward_name = name + Const.FORWARD
|
|
230
|
+
full_backward_name = name + Const.BACKWARD
|
|
231
|
+
else:
|
|
232
|
+
full_forward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.FORWARD
|
|
233
|
+
full_backward_name = name + str(HOOKCell.get_cell_count(name)) + Const.SEP + Const.BACKWARD
|
|
234
|
+
pre_forward_hook = functools.partial(pre_hook, full_forward_name)
|
|
235
|
+
forward_hook = functools.partial(forward_hook, full_forward_name)
|
|
236
|
+
backward_hook = functools.partial(backward_hook, full_backward_name)
|
|
237
|
+
pre_backward_hook = functools.partial(pre_backward_hook, full_backward_name)
|
|
238
|
+
|
|
239
|
+
def wrap_pre_forward_hook(cell, input_data):
|
|
240
|
+
return pre_forward_hook(cell, input_data)
|
|
118
241
|
|
|
119
|
-
def wrap_forward_hook(cell,
|
|
120
|
-
return forward_hook(cell,
|
|
242
|
+
def wrap_forward_hook(cell, input_data, output_data):
|
|
243
|
+
return forward_hook(cell, input_data, output_data)
|
|
121
244
|
|
|
122
245
|
def wrap_backward_hook(cell, grad_input, grad_output):
|
|
123
246
|
return backward_hook(cell, grad_input, grad_output)
|
|
124
247
|
|
|
125
|
-
|
|
248
|
+
def wrap_pre_backward_hook(cell, grad_input):
|
|
249
|
+
return pre_backward_hook(cell, grad_input)
|
|
126
250
|
|
|
251
|
+
return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook, wrap_pre_backward_hook
|
|
127
252
|
|
|
128
253
|
def update_primitive_counters(self, primitive_name):
|
|
129
254
|
if primitive_name not in self.primitive_counters:
|
|
@@ -131,32 +256,20 @@ class Service:
|
|
|
131
256
|
else:
|
|
132
257
|
self.primitive_counters[primitive_name] += 1
|
|
133
258
|
|
|
134
|
-
def register_primitive_hooks(self):
|
|
135
|
-
primitive_set = set()
|
|
136
|
-
for _, cell in self.model.cells_and_names():
|
|
137
|
-
for pname, primitive in cell._primitives.items():
|
|
138
|
-
primitive_set.add((pname, primitive))
|
|
139
|
-
|
|
140
|
-
for pname, primitive in primitive_set:
|
|
141
|
-
NewPrimitive = type('NewPrimitive', (primitive.__class__,),
|
|
142
|
-
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__, pname)})
|
|
143
|
-
primitive.__class__ = NewPrimitive
|
|
144
|
-
|
|
145
259
|
def step(self):
|
|
260
|
+
if self.config.async_dump:
|
|
261
|
+
self.data_collector.fill_stack_tensor_data()
|
|
262
|
+
self.data_collector.data_processor.dump_async_data()
|
|
263
|
+
self.data_collector.write_json()
|
|
146
264
|
self.current_iter += 1
|
|
147
265
|
self.data_collector.update_iter(self.current_iter)
|
|
148
|
-
|
|
149
|
-
CellProcessor.reset_cell_stats()
|
|
150
|
-
self.primitive_hook_service.primitive_counters.clear()
|
|
151
|
-
self.data_collector.data_writer.reset_cache()
|
|
152
|
-
JitDump.jit_count = defaultdict(int)
|
|
266
|
+
self.reset_status()
|
|
153
267
|
|
|
154
268
|
def start(self, model=None):
|
|
155
269
|
self.start_call = True
|
|
156
270
|
if self.should_stop_service:
|
|
157
271
|
return
|
|
158
272
|
if self.need_end_service():
|
|
159
|
-
api_register.api_set_ori_func()
|
|
160
273
|
self.should_stop_service = True
|
|
161
274
|
self.switch = False
|
|
162
275
|
self.primitive_switch = False
|
|
@@ -176,7 +289,8 @@ class Service:
|
|
|
176
289
|
|
|
177
290
|
if self.config.rank and self.current_rank not in self.config.rank:
|
|
178
291
|
return
|
|
179
|
-
self.
|
|
292
|
+
self.register_primitive_hook()
|
|
293
|
+
self.register_cell_hook()
|
|
180
294
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
181
295
|
JitDump.set_config(self.config)
|
|
182
296
|
JitDump.set_data_collector(self.data_collector)
|
|
@@ -195,24 +309,6 @@ class Service:
|
|
|
195
309
|
logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
|
|
196
310
|
JitDump.jit_dump_switch = True
|
|
197
311
|
|
|
198
|
-
def forward_backward_dump_end(self):
|
|
199
|
-
if self.should_stop_service:
|
|
200
|
-
return
|
|
201
|
-
logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
|
|
202
|
-
if not self.start_call:
|
|
203
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
|
|
204
|
-
raise Exception("debugger.start() is not set in the current scope.")
|
|
205
|
-
if not self.switch:
|
|
206
|
-
logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
|
|
207
|
-
"debugger.start() and debugger.stop() ")
|
|
208
|
-
raise Exception("debugger.stop() is already called. ")
|
|
209
|
-
if self.config.step and self.current_iter not in self.config.step:
|
|
210
|
-
return
|
|
211
|
-
if self.config.rank and self.current_rank not in self.config.rank:
|
|
212
|
-
return
|
|
213
|
-
self.primitive_switch = False
|
|
214
|
-
api_register.api_set_ori_func()
|
|
215
|
-
|
|
216
312
|
def stop(self):
|
|
217
313
|
if self.should_stop_service:
|
|
218
314
|
return
|
|
@@ -228,6 +324,9 @@ class Service:
|
|
|
228
324
|
self.switch = False
|
|
229
325
|
self.primitive_switch = False
|
|
230
326
|
self.start_call = False
|
|
327
|
+
if self.config.async_dump:
|
|
328
|
+
self.data_collector.fill_stack_tensor_data()
|
|
329
|
+
self.data_collector.data_processor.dump_async_data()
|
|
231
330
|
self.data_collector.write_json()
|
|
232
331
|
JitDump.jit_dump_switch = False
|
|
233
332
|
|
|
@@ -238,8 +337,16 @@ class Service:
|
|
|
238
337
|
return True
|
|
239
338
|
return False
|
|
240
339
|
|
|
241
|
-
def
|
|
242
|
-
|
|
340
|
+
def should_execute_hook(self, hook_type, cell, is_forward):
|
|
341
|
+
is_cell_hook = hook_type == BaseScope.Module_Type_Module
|
|
342
|
+
if is_cell_hook and not self.switch:
|
|
343
|
+
return False
|
|
344
|
+
elif not is_cell_hook and is_forward and not self.switch:
|
|
345
|
+
return False
|
|
346
|
+
elif not is_cell_hook and not is_forward and not cell.forward_data_collected:
|
|
347
|
+
return False
|
|
348
|
+
|
|
349
|
+
if self.inner_switch:
|
|
243
350
|
return False
|
|
244
351
|
if not self.data_collector or self.data_collector.data_processor.is_terminated:
|
|
245
352
|
return False
|
|
@@ -249,6 +356,12 @@ class Service:
|
|
|
249
356
|
create_directory(self.config.dump_path)
|
|
250
357
|
self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
|
|
251
358
|
cur_rank = self.current_rank if self.current_rank is not None else ''
|
|
359
|
+
if self.config.level == Const.LEVEL_L2:
|
|
360
|
+
create_directory(self.dump_iter_dir)
|
|
361
|
+
kernel_config_path = create_kernel_config_json(self.dump_iter_dir, cur_rank)
|
|
362
|
+
self.config.kernel_config_path = kernel_config_path
|
|
363
|
+
return
|
|
364
|
+
|
|
252
365
|
dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}")
|
|
253
366
|
create_directory(dump_dir)
|
|
254
367
|
if self.config.task in self.data_collector.tasks_need_tensor_data:
|
|
@@ -261,37 +374,96 @@ class Service:
|
|
|
261
374
|
stack_file_path = os.path.join(dump_dir, "stack.json")
|
|
262
375
|
construct_file_path = os.path.join(dump_dir, "construct.json")
|
|
263
376
|
self.data_collector.update_dump_paths(
|
|
264
|
-
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
|
|
377
|
+
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None
|
|
378
|
+
)
|
|
379
|
+
self.data_collector.initialize_json_file(
|
|
380
|
+
framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK
|
|
381
|
+
)
|
|
265
382
|
|
|
266
383
|
def empty(self, *args, **kwargs):
|
|
267
384
|
pass
|
|
268
385
|
|
|
269
|
-
def
|
|
270
|
-
|
|
271
|
-
|
|
386
|
+
def register_api_hook(self):
|
|
387
|
+
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
|
|
388
|
+
logger.info(f"The api {self.config.task} hook function is successfully mounted to the model.")
|
|
272
389
|
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
|
|
273
390
|
api_register.api_set_hook_func()
|
|
274
|
-
if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
|
|
275
|
-
self.register_primitive_hooks()
|
|
276
391
|
|
|
392
|
+
def get_cells_and_names(self):
|
|
393
|
+
cells_and_names_with_index = {}
|
|
394
|
+
|
|
395
|
+
def get_cell_or_module(model):
|
|
396
|
+
return model.named_modules() if is_mindtorch() else model.cells_and_names()
|
|
397
|
+
|
|
398
|
+
if isinstance(self.model, (list, tuple)):
|
|
399
|
+
for index, model in enumerate(self.model):
|
|
400
|
+
cells_and_names_with_index[str(index)] = get_cell_or_module(model)
|
|
401
|
+
else:
|
|
402
|
+
cells_and_names_with_index["-1"] = get_cell_or_module(self.model)
|
|
403
|
+
return cells_and_names_with_index
|
|
404
|
+
|
|
405
|
+
def register_primitive_hook(self):
|
|
406
|
+
if self.config.level not in [Const.LEVEL_MIX, Const.LEVEL_L1]:
|
|
407
|
+
return
|
|
408
|
+
if not self.model or self.config.task not in Const.DUMP_DATA_COLLECTION_LIST:
|
|
409
|
+
return
|
|
410
|
+
|
|
411
|
+
primitive_set = set()
|
|
412
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
413
|
+
for cells_and_names in cells_and_names_with_index.values():
|
|
414
|
+
for _, cell in cells_and_names:
|
|
415
|
+
for attribute, value in vars(cell).items():
|
|
416
|
+
if isinstance(value, Primitive):
|
|
417
|
+
primitive_set.add((attribute, value))
|
|
418
|
+
|
|
419
|
+
for pname, primitive in primitive_set:
|
|
420
|
+
primitive_class_name = primitive.__class__.__name__
|
|
421
|
+
primitive_combined_name = pname + Const.SEP + primitive_class_name
|
|
422
|
+
new_primitive = type('NewPrimitive', (primitive.__class__,),
|
|
423
|
+
{'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
|
|
424
|
+
primitive_combined_name)})
|
|
425
|
+
primitive.__class__ = new_primitive
|
|
426
|
+
|
|
427
|
+
def register_cell_hook(self):
|
|
277
428
|
if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
|
|
429
|
+
logger.info(f"The cell {self.config.task} hook function is successfully mounted to the model.")
|
|
278
430
|
if not self.model:
|
|
279
431
|
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
|
|
280
432
|
f"The current level is {self.config.level}, the model cannot be None")
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
433
|
+
model_type = Const.MODULE if is_mindtorch() else Const.CELL
|
|
434
|
+
cells_and_names_with_index = self.get_cells_and_names()
|
|
435
|
+
|
|
436
|
+
for index, cells_and_names in cells_and_names_with_index.items():
|
|
437
|
+
model = self.model if index == "-1" else self.model[int(index)]
|
|
438
|
+
for name, cell in cells_and_names:
|
|
439
|
+
if cell == model:
|
|
440
|
+
continue
|
|
441
|
+
cell_index = (index + Const.SEP) if index != "-1" else ""
|
|
442
|
+
prefix = (model_type + Const.SEP + cell_index + name +
|
|
443
|
+
Const.SEP + cell.__class__.__name__ + Const.SEP)
|
|
444
|
+
_, forward_hook, backward_hook, _ = self.build_hook(BaseScope.Module_Type_Module, prefix)
|
|
445
|
+
cell.register_forward_hook(forward_hook)
|
|
446
|
+
cell.register_forward_pre_hook(
|
|
447
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.START))
|
|
448
|
+
cell.register_forward_hook(
|
|
449
|
+
self.cell_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
|
|
450
|
+
|
|
451
|
+
register_backward_hook_functions["full"](cell, backward_hook)
|
|
452
|
+
register_backward_hook_functions["pre"](
|
|
453
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START))
|
|
454
|
+
register_backward_hook_functions["full"](
|
|
455
|
+
cell, self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
|
|
456
|
+
|
|
457
|
+
def reset_status(self):
|
|
458
|
+
self.primitive_hook_service.primitive_counters.clear()
|
|
459
|
+
self.data_collector.data_writer.reset_cache()
|
|
460
|
+
JitDump.jit_count = defaultdict(int)
|
|
461
|
+
self.params_grad_info.clear()
|
|
462
|
+
|
|
463
|
+
if self.config.level == Const.LEVEL_L2:
|
|
464
|
+
self.data_collector.data_processor.reset_status()
|
|
465
|
+
return
|
|
466
|
+
if self.config.step and self.current_iter not in self.config.step:
|
|
467
|
+
return
|
|
468
|
+
if self.config.rank and self.current_rank not in self.config.rank:
|
|
469
|
+
return
|
|
@@ -1,3 +1,18 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.const import Const
|
|
2
17
|
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
|
|
3
18
|
from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
|
msprobe/msprobe.py
CHANGED
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
import argparse
|
|
17
17
|
import sys
|
|
18
18
|
import importlib.util
|
|
19
|
-
|
|
19
|
+
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
20
21
|
from msprobe.core.common.log import logger
|
|
22
|
+
from msprobe.core.compare.utils import _compare_parser
|
|
21
23
|
from msprobe.core.compare.compare_cli import compare_cli
|
|
22
|
-
from msprobe.core.
|
|
24
|
+
from msprobe.core.compare.merge_result.merge_result_cli import _merge_result_parser, merge_result_cli
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
def is_module_available(module_name):
|
|
@@ -45,10 +47,20 @@ def main():
|
|
|
45
47
|
multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
|
|
46
48
|
api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
|
|
47
49
|
run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
|
|
50
|
+
code_mapping_cmd_parser = subparsers.add_parser('code_mapping')
|
|
51
|
+
graph_service_cmd_parser = subparsers.add_parser('graph')
|
|
52
|
+
op_generate_cmd_parser = subparsers.add_parser('op_generate')
|
|
53
|
+
merge_result_parser = subparsers.add_parser('merge_result')
|
|
48
54
|
_compare_parser(compare_cmd_parser)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
55
|
+
_merge_result_parser(merge_result_parser)
|
|
56
|
+
|
|
57
|
+
is_torch_available = is_module_available("torch")
|
|
58
|
+
|
|
59
|
+
if len(sys.argv) < 4:
|
|
60
|
+
parser.print_help()
|
|
61
|
+
sys.exit(0)
|
|
62
|
+
framework_args = parser.parse_args(sys.argv[1:3])
|
|
63
|
+
if framework_args.framework == Const.PT_FRAMEWORK:
|
|
52
64
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
|
|
53
65
|
from msprobe.pytorch.parse_tool.cli import parse as cli_parse
|
|
54
66
|
from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
|
|
@@ -56,20 +68,29 @@ def main():
|
|
|
56
68
|
_api_precision_compare_command
|
|
57
69
|
from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
|
|
58
70
|
_run_overflow_check_command
|
|
71
|
+
from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
|
|
72
|
+
from msprobe.pytorch.api_accuracy_checker.generate_op_script.op_generator import _op_generator_parser, \
|
|
73
|
+
_run_operator_generate_commond
|
|
59
74
|
|
|
60
75
|
_run_ut_parser(run_ut_cmd_parser)
|
|
61
76
|
_run_ut_parser(multi_run_ut_cmd_parser)
|
|
62
77
|
multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
|
|
63
|
-
|
|
78
|
+
help='Number of splits for parallel processing. Range: 1-64')
|
|
64
79
|
_api_precision_compare_parser(api_precision_compare_cmd_parser)
|
|
65
80
|
_run_overflow_check_parser(run_overflow_check_cmd_parser)
|
|
66
|
-
|
|
81
|
+
_pt_graph_service_parser(graph_service_cmd_parser)
|
|
82
|
+
_op_generator_parser(op_generate_cmd_parser)
|
|
83
|
+
elif framework_args.framework == Const.MS_FRAMEWORK:
|
|
67
84
|
from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
|
|
85
|
+
from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
|
|
68
86
|
add_api_accuracy_checker_argument(run_ut_cmd_parser)
|
|
87
|
+
from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
|
|
88
|
+
multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
|
|
89
|
+
from msprobe.mindspore.code_mapping.cmd_parser import add_ir_parser_arguments
|
|
90
|
+
add_ir_parser_arguments(code_mapping_cmd_parser)
|
|
91
|
+
|
|
92
|
+
_ms_graph_service_parser(graph_service_cmd_parser)
|
|
69
93
|
|
|
70
|
-
if len(sys.argv) == 1:
|
|
71
|
-
parser.print_help()
|
|
72
|
-
sys.exit(0)
|
|
73
94
|
args = parser.parse_args(sys.argv[1:])
|
|
74
95
|
if sys.argv[2] == Const.PT_FRAMEWORK:
|
|
75
96
|
if not is_torch_available:
|
|
@@ -86,20 +107,37 @@ def main():
|
|
|
86
107
|
_api_precision_compare_command(args)
|
|
87
108
|
elif sys.argv[3] == "run_overflow_check":
|
|
88
109
|
_run_overflow_check_command(args)
|
|
110
|
+
elif sys.argv[3] == "graph":
|
|
111
|
+
_pt_graph_service_command(args)
|
|
112
|
+
elif sys.argv[3] == 'op_generate':
|
|
113
|
+
_run_operator_generate_commond(args)
|
|
89
114
|
elif sys.argv[3] == "compare":
|
|
90
115
|
if args.cell_mapping is not None or args.api_mapping is not None:
|
|
91
116
|
logger.error("Argument -cm or -am is not supported in PyTorch framework")
|
|
92
117
|
raise Exception("Argument -cm or -am is not supported in PyTorch framework")
|
|
93
118
|
compare_cli(args)
|
|
119
|
+
elif sys.argv[3] == "merge_result":
|
|
120
|
+
merge_result_cli(args)
|
|
94
121
|
else:
|
|
95
122
|
if not is_module_available(Const.MS_FRAMEWORK):
|
|
96
123
|
logger.error("MindSpore does not exist, please install MindSpore library")
|
|
97
124
|
raise Exception("MindSpore does not exist, please install MindSpore library")
|
|
98
125
|
if sys.argv[3] == "compare":
|
|
99
126
|
compare_cli(args)
|
|
127
|
+
elif sys.argv[3] == "merge_result":
|
|
128
|
+
merge_result_cli(args)
|
|
100
129
|
elif sys.argv[3] == "run_ut":
|
|
101
130
|
from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
|
|
102
131
|
api_checker_main(args)
|
|
132
|
+
elif sys.argv[3] == "multi_run_ut":
|
|
133
|
+
from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
|
|
134
|
+
mul_api_checker_main(args)
|
|
135
|
+
elif sys.argv[3] == "graph":
|
|
136
|
+
_ms_graph_service_command(args)
|
|
137
|
+
elif sys.argv[3] == "code_mapping":
|
|
138
|
+
from msprobe.mindspore.code_mapping.main import code_mapping_main
|
|
139
|
+
code_mapping_main(args)
|
|
140
|
+
|
|
103
141
|
|
|
104
142
|
if __name__ == "__main__":
|
|
105
143
|
main()
|