mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/parse_tool/lib/parse_tool.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,17 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
 import argparse
 import os
 from collections import namedtuple
 
+from msprobe.core.common.file_utils import create_directory
+from msprobe.pytorch.parse_tool.lib.compare import Compare
 from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.parse_exception import catch_exception, ParseException
 from msprobe.pytorch.parse_tool.lib.utils import Util
-from msprobe.pytorch.parse_tool.lib.compare import Compare
 from msprobe.pytorch.parse_tool.lib.visualization import Visualization
-
-from msprobe.core.common.file_utils import create_directory
+
 
 class ParseTool:
     def __init__(self):
@@ -117,7 +117,8 @@ class ParseTool:
         self.util.check_path_valid(args.golden_dump_path)
         self.util.check_file_path_format(args.my_dump_path, Const.NPY_SUFFIX)
         self.util.check_file_path_format(args.golden_dump_path, Const.NPY_SUFFIX)
-        compare_data_args = namedtuple('compare_data_args',
+        compare_data_args = namedtuple('compare_data_args',
+                                       ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
         compare_data_args.__new__.__defaults__ = (False, 0.001, 0.001, 20)
         res = compare_data_args(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count)
         self.compare.compare_data(res)
```
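Note: the rewrapped call above relies on the `namedtuple` defaults idiom — assigning to `__new__.__defaults__` right-aligns the defaults tuple against the field list, so the last four fields become optional while the two dump paths stay required. A minimal standalone sketch (paths are illustrative):

```python
from collections import namedtuple

CompareDataArgs = namedtuple('compare_data_args',
                             ['my_dump_path', 'golden_dump_path', 'save', 'rtol', 'atol', 'count'])
# The 4-tuple covers the last four fields: save, rtol, atol, count.
CompareDataArgs.__new__.__defaults__ = (False, 0.001, 0.001, 20)

# Only the two required paths need to be passed explicitly.
args = CompareDataArgs('my_dump.npy', 'golden_dump.npy')
print(args.save, args.rtol, args.atol, args.count)  # False 0.001 0.001 20
```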
msprobe/pytorch/parse_tool/lib/utils.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,24 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
+import hashlib
 import os
 import re
-import sys
 import subprocess
-import
+import sys
 import time
-import numpy as np
 from collections import namedtuple
-
-
-from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
-from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\
-    check_path_executable, check_path_owner_consistent
+
+import numpy as np
 from msprobe.core.common.const import FileCheckConst
+from msprobe.core.common.file_utils import change_mode, check_other_user_writable, \
+    check_path_executable, check_path_owner_consistent
 from msprobe.core.common.file_utils import check_file_or_directory_path, remove_path, check_file_type, os_walk_for_files
 from msprobe.pytorch.common.log import logger
-
+from msprobe.pytorch.parse_tool.lib.config import Const
+from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
+from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
 
 try:
     from rich.traceback import install
@@ -135,7 +134,7 @@ class Util:
         zero_mask = (data == 0)
         data[zero_mask] += np.finfo(float).eps
         return data
-
+
     @staticmethod
     def dir_contains_only(path, endfix):
         files = os_walk_for_files(path, Const.MAX_TRAVERSAL_DEPTH)
@@ -143,11 +142,11 @@ class Util:
             if not file['file'].endswith(endfix):
                 return False
         return True
-
+
     @staticmethod
     def localtime_str():
         return time.strftime("%Y%m%d%H%M%S", time.localtime())
-
+
     @staticmethod
     def change_filemode_safe(path):
         change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -208,7 +207,7 @@ class Util:
 
     def list_numpy_files(self, path, extern_pattern=''):
         return self.list_file_with_pattern(path, Const.NUMPY_PATTERN, extern_pattern,
-
+                                           self._gen_numpy_file_info)
 
     def create_columns(self, content):
         if not Columns:
```
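Note: the `zero_mask` context lines in the `Util` hunk above are the standard guard against division by zero in relative-error comparisons: exact zeros are nudged up by machine epsilon before the array is used as a denominator. A minimal sketch of the effect (the function name here is illustrative, not the msprobe method name):

```python
import numpy as np

def protect_zero_denominators(data):
    # Shift exact zeros to machine epsilon so later divisions stay finite.
    zero_mask = (data == 0)
    data[zero_mask] += np.finfo(float).eps
    return data

golden = np.array([0.0, 1.0, 2.0])
safe_golden = protect_zero_denominators(golden.copy())
print(np.abs(golden - safe_golden) / safe_golden)  # finite everywhere, no warnings
```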
msprobe/pytorch/parse_tool/lib/visualization.py CHANGED

```diff
@@ -1,8 +1,7 @@
-
-#
-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -13,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
 import json
-import numpy as np
 
+import numpy as np
+from msprobe.core.common.file_utils import FileOpen, load_npy, save_npy_to_txt
 from msprobe.pytorch.parse_tool.lib.config import Const
-from msprobe.pytorch.parse_tool.lib.utils import Util
 from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
-from msprobe.
+from msprobe.pytorch.parse_tool.lib.utils import Util
 
 
 class Visualization:
@@ -77,7 +76,7 @@ class Visualization:
                 self.util.log.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2]))
                 self.util.log.info(" {}".format(item[3]))
                 continue
-            if len(msg) > 5 and len(msg[5]) >=
+            if len(msg) > 5 and len(msg[5]) >= 3:
                 summery_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
                     .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
                 if not title_printed:
```
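Note: the repaired condition checks both the outer `msg` tuple and its statistics sub-sequence before indexing into them. A toy sketch, assuming the `(name, _, _, dtype, shape, (max, min, mean))` layout that the format call implies:

```python
msg = ("Tensor.add.0.forward", None, None, "float32", (2, 3), (1.5, -0.2, 0.6))

# Guard the outer tuple and the stats sub-sequence before indexing.
if len(msg) > 5 and len(msg[5]) >= 3:
    summery_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
        .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
    print(summery_info)
```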
msprobe/pytorch/service.py CHANGED

```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +19,7 @@ from collections import namedtuple
 
 import torch
 from msprobe.core.common.const import Const
-from msprobe.core.common.exceptions import DistributedNotInitializedError
+from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import create_directory
 from msprobe.core.common.utils import print_tools_ends_info
 from msprobe.core.data_dump.data_collector import build_data_collector
@@ -29,10 +29,10 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import get_rank_if_initialized
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
-from msprobe.pytorch.
+from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_registry import api_register
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.
+from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
@@ -48,100 +48,175 @@ class Service:
         self.data_collector = build_data_collector(config)
         self.module_processor = ModuleProcesser(self.data_collector.scope)
         self.switch = False
+        self.inner_switch = False
         self.current_iter = 0
         self.first_start = True
         self.current_rank = None
         self.dump_iter_dir = None
         self.should_stop_service = False
         self.attl = None
-
-
-
-        logger.info_on_rank_0("Data needed ends here.")
-        api_register.api_originality()
-
-    @staticmethod
-    def is_registered_backward_hook(module):
-        if hasattr(module, '_backward_hooks') and \
-                len(module._backward_hooks) > 0 and \
-                module._is_full_backward_hook is False:
-            return True
-        return False
-
-    def check_register_full_backward_hook(self, module):
-        if self.is_registered_backward_hook(module):
-            module._backward_hooks.clear()
-            module._is_full_backward_hook = None
-            logger.warning("Found deprecated backward hooks. Removing them and switching to full backward hooks.")
+        self.params_grad_info = {}
+        # Register early so that as many API hooks as possible are in place
+        self.register_api_hook()
 
     def build_hook(self, module_type, name):
         def pre_hook(api_or_module_name, module, args, kwargs):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, True):
                 return args, kwargs
 
+            self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name
+                api_or_module_name = module.mindstudio_reserved_name[-1]
+            else:
+                module.forward_data_collected = True
+                HOOKModule.add_module_count(name)
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
+                self.inner_switch = False
                 return None, None
             if self.data_collector:
                 module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None)
-                self.data_collector.
+                self.data_collector.forward_input_data_collect(api_or_module_name, module, pid, module_input_output)
+
+            self.inner_switch = False
             return args, kwargs
 
+        def grad_hook(module, ori_name, param_name):
+            def hook_fn(grad):
+                if not self.should_execute_hook(module_type, module, False):
+                    return grad
+                self.inner_switch = True
+                self.data_collector.params_data_collect(ori_name, param_name, pid, grad)
+                self.inner_switch = False
+                return grad
+
+            return hook_fn
+
+        def register_param_hook(ori_name, module, params_dict):
+            '''
+            Register parameter hooks
+            '''
+            # When data_mode is forward-only, do not register parameter hooks
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                for param_name, param in params_dict.items():
+                    if param.requires_grad:
+                        param.register_hook(grad_hook(module, ori_name, param_name))
+
+        def init_params_grad_info(module, params_dict):
+            '''
+            Initialize parameter gradient info: after the forward hook finishes, write the
+            parameter gradient info into cache_data as a placeholder
+            '''
+            if not params_dict:
+                return
+            if not (Const.FORWARD in self.config.data_mode and Const.BACKWARD not in self.config.data_mode):
+                grad_name = module.params_grad_name if hasattr(module, 'params_grad_name') else None
+                # Check whether a placeholder already exists in cache_data; if not, write one first
+                if not self.params_grad_info.get(grad_name):
+                    data_info = {grad_name: {key: [None] for key, value in params_dict.items() if value.requires_grad}}
+                    # A placeholder is only needed when the module has parameters with requires_grad=True, since only those produce gradients
+                    if data_info.get(grad_name):
+                        # Write the data_info for grad_name into cache_data first; it is updated after the gradients are computed
+                        self.data_collector.handle_data(grad_name, data_info,
+                                                        flush=self.data_collector.data_processor.is_terminated)
+                        # Record that this module's parameter gradient info already has its placeholder
+                        self.params_grad_info[grad_name] = True
+
         def forward_hook(api_or_module_name, module, args, kwargs, output):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, True):
                 return None
 
-
-            api_or_module_name = module.mindstudio_reserved_name
-            self.data_collector.update_api_or_module_name(api_or_module_name)
-
+            self.inner_switch = True
             if self.config.online_run_ut:
+                self.data_collector.update_api_or_module_name(api_or_module_name)
                 if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name):
                     return None
-                api_data = ApiData(
+                api_data = ApiData(
+                    api_or_module_name[:-len(Const.FORWARD_NAME_SUFFIX)],
+                    args,
+                    kwargs,
+                    output,
+                    self.current_iter,
+                    self.current_rank
+                )
                 self.attl_send(api_data)
+                self.inner_switch = False
                 return None
 
-
-
-
-
-
+            module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            if module_type == BaseScope.Module_Type_Module:
+                api_or_module_name = module.mindstudio_reserved_name[-1]
+                self.data_collector.update_api_or_module_name(api_or_module_name)
+                params_dict = {key.split(Const.SEP)[-1]: value for key, value in module.named_parameters(recurse=False)}
+                setattr(module_input_output, Const.PARAMS, params_dict)
+                # Decide whether parameter hooks need to be registered
+                if not hasattr(module, 'params_grad_name') and params_dict:
+                    ori_name = api_or_module_name.rsplit(Const.SEP, 2)[0]
+                    grad_name = ori_name + Const.SEP + Const.PARAMS_GRAD
+                    # On the first forward-hook call, add the params_grad_name attribute and register parameter hooks
+                    setattr(module, 'params_grad_name', grad_name)
+                    register_param_hook(ori_name, module, params_dict)
+                self.data_collector.forward_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output
+                )
+                init_params_grad_info(module, params_dict)
+            else:
+                self.data_collector.update_api_or_module_name(api_or_module_name)
+                self.data_collector.forward_output_data_collect(
+                    api_or_module_name,
+                    module,
+                    pid,
+                    module_input_output
+                )
+
+            if self.data_collector.if_return_forward_new_output():
+                forward_new_output = self.data_collector.get_forward_new_output()
+                self.inner_switch = False
+                return forward_new_output
+            self.inner_switch = False
             return output
 
         def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
             return forward_hook(api_or_module_name, module, args, {}, output)
 
         def backward_hook(api_or_module_name, module, grad_input, grad_output):
-            if not self.should_execute_hook():
+            if not self.should_execute_hook(module_type, module, False):
                 return
 
+            self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name
+                api_or_module_name = module.mindstudio_reserved_name[-1]
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
+                self.inner_switch = False
                 return
 
             if self.data_collector:
                 # The grad_input obtained here is actually the output of the backward pass, and grad_output its input, so the two are swapped when passed on
                 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
                 self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
+            self.inner_switch = False
 
         pid = os.getpid()
-
-
-
-
-
-
-
+        full_forward_name = None
+        full_backward_name = None
+        if module_type == BaseScope.Module_Type_API:
+            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
+            full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
+        pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
+        forward_hook_fn = functools.partial(forward_hook, full_forward_name)
+        backward_hook_fn = functools.partial(backward_hook, full_backward_name)
+        forward_hook_torch_version_below_2_fn = functools.partial(
+            forward_hook_torch_version_below_2,
+            full_forward_name
+        )
         return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
 
-    def start(self, model
+    def start(self, model):
         if self.need_stop_service():
             return
 
```
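Note: the new `grad_hook`/`register_param_hook` pair collects parameter gradients through PyTorch's tensor-level `Tensor.register_hook`, which fires once the gradient for that parameter has been computed and should hand the gradient back unchanged. A minimal standalone sketch of the pattern (module and names are illustrative, not msprobe API):

```python
import torch
import torch.nn as nn

collected = {}

def grad_hook(ori_name, param_name):
    def hook_fn(grad):
        # Fires when this parameter's gradient is ready; record a copy.
        collected[ori_name + "." + param_name] = grad.detach().clone()
        return grad  # return the gradient unchanged so training is unaffected
    return hook_fn

module = nn.Linear(4, 2)
for param_name, param in module.named_parameters(recurse=False):
    if param.requires_grad:
        param.register_hook(grad_hook("Module.fc.Linear", param_name))

module(torch.randn(3, 4)).sum().backward()
print(sorted(collected))  # ['Module.fc.Linear.bias', 'Module.fc.Linear.weight']
```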
```diff
@@ -155,10 +230,8 @@ class Service:
 
             if self.config.rank and self.current_rank not in self.config.rank:
                 return
-            self.
+            self.register_module_hook()
             self.first_start = False
-            if api_origin:
-                api_register.api_modularity()
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, True, self.config.online_run_ut_recompute)
         self.switch = True
@@ -170,30 +243,31 @@ class Service:
     def stop(self):
         if self.should_stop_service:
             return
-        if self.config.level == "L2":
-            return
         if self.config.step and self.current_iter not in self.config.step:
             return
         if self.config.rank and self.current_rank not in self.config.rank:
             return
         self.switch = False
+        if self.config.level == Const.LEVEL_L2:
+            return
         if self.config.online_run_ut and torch_version_above_or_equal_2:
             run_ut_dispatch(self.attl, False, self.config.online_run_ut_recompute)
             return
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
 
     def step(self):
         if self.should_stop_service:
             return
+        if self.config.async_dump:
+            self.data_collector.fill_stack_tensor_data()
+            self.data_collector.data_processor.dump_async_data()
+        self.data_collector.write_json()
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
-
-        ModuleProcesser.reset_module_stats()
-        HOOKModule.reset_module_stats()
-        self.data_collector.data_writer.reset_cache()
-
-        if self.config.level == Const.LEVEL_L2:
-            self.data_collector.data_processor.reset_status()
+        self.reset_status()
 
     def need_stop_service(self):
         if self.should_stop_service:
@@ -204,8 +278,6 @@ class Service:
         if self.config.online_run_ut:
             # send stop signal if online_run_ut
             self.attl_stop()
-        if self.config.level in [Const.LEVEL_L1, Const.LEVEL_L2, Const.LEVEL_MIX]:
-            api_register.api_originality()
         self.switch = False
         self.should_stop_service = True
         print_tools_ends_info()
@@ -214,10 +286,18 @@ class Service:
             return True
         return False
 
-    def should_execute_hook(self):
-
+    def should_execute_hook(self, hook_type, module, is_forward):
+        is_module_hook = hook_type == BaseScope.Module_Type_Module
+        if is_module_hook and not self.switch:
+            return False
+        elif not is_module_hook and is_forward and not self.switch:
             return False
-
+        elif not is_module_hook and not is_forward and not module.forward_data_collected:
+            return False
+
+        if self.inner_switch:
+            return False
+        if not self.data_collector or self.data_collector.data_processor.is_terminated:
             return False
         return True
 
```
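Note: `should_execute_hook` now also consults the `inner_switch` flag added in `__init__`. Work done inside a hook body (statistics, dump writes) can itself invoke hooked APIs, and without a reentrancy guard the collector would recurse on its own bookkeeping ops. A toy sketch of the guard (all names illustrative, not msprobe API):

```python
class Recorder:
    """Toy stand-in for the Service switch logic."""
    def __init__(self):
        self.inner_switch = False
        self.records = []

    def traced_op(self, name, value):
        # Every traced op fires the hook, including ops run from a hook body.
        self.hook(name, value)
        return value * 2

    def hook(self, name, value):
        if self.inner_switch:  # re-entrant call from inside a hook body: skip
            return
        self.inner_switch = True
        # Collecting a statistic here invokes another traced op; the guard
        # keeps that inner call from being recorded or recursing further.
        self.records.append((name, self.traced_op("internal.abs", abs(value))))
        self.inner_switch = False

recorder = Recorder()
recorder.traced_op("Functional.relu.0.forward", -3)
print(recorder.records)  # [('Functional.relu.0.forward', 6)] — inner op skipped
```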
```diff
@@ -244,50 +324,26 @@
         construct_file_path = os.path.join(dump_dir, "construct.json")
         free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv")
         self.data_collector.update_dump_paths(
-            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
-
-
-
-
-
-
-
-
-
-
-        prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
-                 module.__class__.__name__ + Const.SEP
-
-        pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = self.build_hook(
-            BaseScope.Module_Type_Module, prefix)
-        if torch_version_above_or_equal_2:
-            module.register_forward_hook(forward_hook, with_kwargs=True)
-        else:
-            self.check_register_full_backward_hook(module)
-            module.register_full_backward_hook(
-                self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-            module.register_forward_hook(forward_hook_torch_version_below_2)
-            self.check_register_full_backward_hook(module)
-            module.register_full_backward_hook(backward_hook)
-
-        module.register_forward_pre_hook(
-            self.module_processor.node_hook(prefix + Const.FORWARD, Const.START))
-        module.register_forward_hook(
-            self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP))
-        if torch_version_above_or_equal_2:
-            module.register_full_backward_pre_hook(
-                self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START))
-            self.check_register_full_backward_hook(module)
-            module.register_full_backward_hook(
-                self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP))
-
-        if self.config.level in ["mix", "L1", "L2"]:
-            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API),
-                                         self.config.online_run_ut)
+            dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path
+        )
+        self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK)
+
+    def register_api_hook(self):
+        if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]:
+            logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.")
+            api_register.initialize_hook(
+                functools.partial(self.build_hook, BaseScope.Module_Type_API),
+                self.config.online_run_ut
+            )
             api_register.api_modularity()
 
-        if
-
+        if self.config.level == Const.LEVEL_MIX:
+            register_optimizer_hook(self.data_collector)
+
+    def register_module_hook(self):
+        if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]:
+            logger.info_on_rank_0(f"The module {self.config.task} hook function is successfully mounted to the model.")
+            self.module_processor.register_module_hook(self.model, self.build_hook)
 
     def attl_init(self):
         if self.config.online_run_ut:
@@ -319,3 +375,17 @@ class Service:
         elif self.attl.socket_manager is not None:
             logger.info(f"pid: {os.getpid()} finished, start send STOP signal.")
             self.attl.socket_manager.send_stop_signal()
+
+    def reset_status(self):
+        ModuleProcesser.reset_module_stats()
+        HOOKModule.reset_module_stats()
+        self.data_collector.data_writer.reset_cache()
+        self.params_grad_info.clear()
+
+        if self.config.level == Const.LEVEL_L2:
+            self.data_collector.data_processor.reset_status()
+            return
+        if self.config.step and self.current_iter not in self.config.step:
+            return
+        if self.config.rank and self.current_rank not in self.config.rank:
+            return
```
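Note: `register_api_hook` and `build_hook` pre-bind the fully qualified API name into each callback with `functools.partial`, so the registered callable still matches the plain forward-hook signature (`module, args, kwargs, output` under `with_kwargs=True`). A minimal sketch, with an illustrative name:

```python
import functools

def forward_hook(api_or_module_name, module, args, kwargs, output):
    print(f"[{api_or_module_name}] {len(args)} positional inputs")

# Bind the name now; the remaining parameters match the hook signature that
# torch.nn.Module.register_forward_hook(..., with_kwargs=True) expects.
full_forward_name = "Tensor.add.0.forward"
forward_hook_fn = functools.partial(forward_hook, full_forward_name)

forward_hook_fn(object(), (1, 2), {}, None)  # [Tensor.add.0.forward] 2 positional inputs
```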