mindstudio-probe 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
- mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
- mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
- mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
- mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
- mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
- msprobe/README.md +182 -0
- msprobe/__init__.py +0 -0
- msprobe/config/README.md +397 -0
- msprobe/config/config.json +28 -0
- msprobe/config/img/free_benchmark.png +0 -0
- msprobe/core/common/const.py +241 -0
- msprobe/core/common/exceptions.py +88 -0
- msprobe/core/common/file_check.py +265 -0
- msprobe/core/common/log.py +55 -0
- msprobe/core/common/utils.py +516 -0
- msprobe/core/common_config.py +58 -0
- msprobe/core/data_dump/data_collector.py +140 -0
- msprobe/core/data_dump/data_processor/base.py +245 -0
- msprobe/core/data_dump/data_processor/factory.py +61 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
- msprobe/core/data_dump/json_writer.py +116 -0
- msprobe/core/data_dump/scope.py +178 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/debugger/__init__.py +0 -0
- msprobe/mindspore/debugger/debugger_config.py +51 -0
- msprobe/mindspore/debugger/precision_debugger.py +32 -0
- msprobe/mindspore/doc/dump.md +65 -0
- msprobe/mindspore/dump/__init__.py +0 -0
- msprobe/mindspore/dump/api_kbk_dump.py +55 -0
- msprobe/mindspore/dump/dump_tool_factory.py +38 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
- msprobe/mindspore/ms_config.py +78 -0
- msprobe/mindspore/overflow_check/__init__.py +0 -0
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
- msprobe/mindspore/task_handler_factory.py +21 -0
- msprobe/msprobe.py +67 -0
- msprobe/pytorch/__init__.py +4 -0
- msprobe/pytorch/advisor/advisor.py +124 -0
- msprobe/pytorch/advisor/advisor_const.py +59 -0
- msprobe/pytorch/advisor/advisor_result.py +58 -0
- msprobe/pytorch/api_accuracy_checker/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
- msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
- msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
- msprobe/pytorch/common/__init__.py +2 -0
- msprobe/pytorch/common/compare_script.template +14 -0
- msprobe/pytorch/common/log.py +32 -0
- msprobe/pytorch/common/parse_json.py +37 -0
- msprobe/pytorch/common/utils.py +224 -0
- msprobe/pytorch/compare/acc_compare.py +1024 -0
- msprobe/pytorch/compare/distributed_compare.py +111 -0
- msprobe/pytorch/compare/highlight.py +100 -0
- msprobe/pytorch/compare/mapping.yaml +607 -0
- msprobe/pytorch/compare/match.py +36 -0
- msprobe/pytorch/compare/npy_compare.py +244 -0
- msprobe/pytorch/debugger/__init__.py +0 -0
- msprobe/pytorch/debugger/debugger_config.py +86 -0
- msprobe/pytorch/debugger/precision_debugger.py +95 -0
- msprobe/pytorch/doc/FAQ.md +193 -0
- msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
- msprobe/pytorch/doc/atat精度工具数据dump标准性能基线报告.md +182 -0
- msprobe/pytorch/doc/dump.md +207 -0
- msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
- msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
- msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
- msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
- msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
- msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
- msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
- msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
- msprobe/pytorch/doc/img/cpu_info.png +0 -0
- msprobe/pytorch/doc/img/module_compare.png +0 -0
- msprobe/pytorch/doc/parse_tool.md +286 -0
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
- msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
- msprobe/pytorch/doc/run_overflow_check.md +25 -0
- msprobe/pytorch/doc/在线精度比对.md +90 -0
- msprobe/pytorch/free_benchmark/__init__.py +8 -0
- msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/common/constant.py +67 -0
- msprobe/pytorch/free_benchmark/common/counter.py +72 -0
- msprobe/pytorch/free_benchmark/common/enums.py +37 -0
- msprobe/pytorch/free_benchmark/common/params.py +129 -0
- msprobe/pytorch/free_benchmark/common/utils.py +98 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
- msprobe/pytorch/free_benchmark/main.py +102 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
- msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
- msprobe/pytorch/functional/__init__.py +0 -0
- msprobe/pytorch/functional/data_processor.py +0 -0
- msprobe/pytorch/functional/dump_module.py +39 -0
- msprobe/pytorch/hook_module/__init__.py +1 -0
- msprobe/pytorch/hook_module/api_registry.py +161 -0
- msprobe/pytorch/hook_module/hook_module.py +109 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
- msprobe/pytorch/hook_module/utils.py +29 -0
- msprobe/pytorch/hook_module/wrap_aten.py +100 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
- msprobe/pytorch/hook_module/wrap_functional.py +108 -0
- msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
- msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
- msprobe/pytorch/hook_module/wrap_torch.py +88 -0
- msprobe/pytorch/hook_module/wrap_vf.py +64 -0
- msprobe/pytorch/module_processer.py +98 -0
- msprobe/pytorch/online_dispatch/__init__.py +20 -0
- msprobe/pytorch/online_dispatch/compare.py +236 -0
- msprobe/pytorch/online_dispatch/dispatch.py +274 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
- msprobe/pytorch/online_dispatch/single_compare.py +391 -0
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
- msprobe/pytorch/online_dispatch/utils.py +187 -0
- msprobe/pytorch/parse.py +4 -0
- msprobe/pytorch/parse_tool/__init__.py +0 -0
- msprobe/pytorch/parse_tool/cli.py +32 -0
- msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
- msprobe/pytorch/parse_tool/lib/compare.py +259 -0
- msprobe/pytorch/parse_tool/lib/config.py +51 -0
- msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
- msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
- msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
- msprobe/pytorch/parse_tool/lib/utils.py +367 -0
- msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
- msprobe/pytorch/pt_config.py +93 -0
- msprobe/pytorch/service.py +167 -0
- msprobe/test/core_ut/common/test_utils.py +345 -0
- msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
- msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
- msprobe/test/core_ut/data_dump/test_scope.py +151 -0
- msprobe/test/core_ut/test_common_config.py +152 -0
- msprobe/test/core_ut/test_file_check.py +218 -0
- msprobe/test/core_ut/test_log.py +109 -0
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
- msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
- msprobe/test/mindspore_ut/test_ms_config.py +69 -0
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
- msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
- msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
- msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
- msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
- msprobe/test/pytorch_ut/test_pt_config.py +69 -0
- msprobe/test/pytorch_ut/test_service.py +59 -0
- msprobe/test/resources/advisor.txt +3 -0
- msprobe/test/resources/compare_result_20230703104808.csv +9 -0
- msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
- msprobe/test/resources/config.yaml +3 -0
- msprobe/test/resources/npu_test.pkl +8 -0
- msprobe/test/run_test.sh +30 -0
- msprobe/test/run_ut.py +58 -0
- msprobe/test/test_module_processer.py +64 -0
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
# Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
"""
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import re
|
|
20
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, \
|
|
21
|
+
check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid
|
|
22
|
+
from msprobe.pytorch.compare.acc_compare import compare_core
|
|
23
|
+
from msprobe.core.common.file_check import create_directory
|
|
24
|
+
from msprobe.pytorch.common.log import logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
    """Compare every rank of an NPU distributed dump with the corresponding benchmark rank.

    The rank folders found in ``npu_dump_dir`` and ``bench_dump_dir`` are validated,
    sorted and paired by position. For every pair, ``compare_core`` is run with the
    suffix ``_<npu_rank>-<bench_rank>`` so each pair gets its own result files in
    ``output_path``.

    Args:
        npu_dump_dir (str): directory that contains the NPU rank folders (``rank``, ``rank0``, ...).
        bench_dump_dir (str): directory that contains the benchmark rank folders.
        output_path (str): directory for the compare results (passed to ``create_directory``).
        **kwargs: forwarded to ``compare_core``. ``stack_mode`` (default False),
            ``auto_analyze`` (default True) and ``fuzzy_match`` (default False) are
            validated here as well. ``suffix`` is rejected because it is generated per
            rank pair.

    Raises:
        CompareException: ``suffix`` was given; a dump dir contains unexpected entries;
            the two runs have a different number of rank folders or none at all; or a
            rank folder lacks the required json file.

    Note:
        If argument validation for a rank pair fails, the process exits via
        ``sys.exit`` with the exception's code.
    """

    def check_and_return_dir_contents(dump_dir, prefix):
        """
        Check the given dump dir and validate that every entry in it matches the pattern
        ^{prefix}(?:0|[1-9][0-9]*)?$, i.e. the prefix optionally followed by a rank number
        written without leading zeros (e.g. rank, rank0, rank7, rank10, rank128).

        Args:
            dump_dir (str): dump dir
            prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only

        Returns:
            content [list]: dir contents
        Raises:
            CompareException: invalid path, or an entry that does not match the pattern
            ValueError: prefix not match the patterns
        """
        check_regex_prefix_format_valid(prefix)
        check_file_or_directory_path(dump_dir, True)
        contents = os.listdir(dump_dir)
        # Rank suffix: "0" or a number without leading zeros. Digits after the first may be
        # any of 0-9, so rank10 / rank20 / rank100 are accepted.
        pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
        for name in contents:
            if not pattern.match(name):
                logger.error(
                    f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                    f"output. Please check and delete irrelevant files in {dump_dir} and try again."
                )
                raise CompareException(CompareException.INVALID_PATH_ERROR)
        return contents

    def extract_json(dirname, stack_json=False):
        """Return the path of a json file of the requested kind inside ``dirname``.

        Args:
            dirname (str): one rank dump directory.
            stack_json (bool): True to return a json whose file name contains 'stack',
                False to return a json whose file name does not contain 'stack'.

        Returns:
            str: full path of the first matching json file in ``os.listdir`` order.

        Raises:
            CompareException: no json file of the requested kind exists in ``dirname``.
        """
        want_stack = bool(stack_json)
        for fname in os.listdir(dirname):
            if not fname.endswith('.json'):
                continue
            # Inspect only the file name: a parent directory whose path happens to contain
            # 'stack' must not change which file is selected.
            # NOTE(review): if several non-stack json files exist, the first one listed is
            # used (os.listdir order is arbitrary) — confirm only one data json is expected.
            if ('stack' in fname) == want_stack:
                return os.path.join(dirname, fname)

        # Provide robustness on invalid directory inputs: never fall back to a json of the
        # wrong kind, which would only fail later with a confusing error.
        kind = 'stack json' if want_stack else 'dump json'
        logger.error(f'No {kind} file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)

    if kwargs.get('suffix'):
        logger.error("Argument 'suffix' is not supported for compare_distributed.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    fuzzy_match = kwargs.get('fuzzy_match', False)
    # Get the ranks and match them by order. Both lists are sorted the same (lexicographic)
    # way, so identically named rank folders pair up.
    npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank'))
    bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank'))
    if len(npu_ranks) != len(bench_ranks):
        logger.error('The number of ranks in the two runs are different. '
                     'Unable to match the ranks. Please use another folder to compare '
                     'or use compare() api and manually match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    if not npu_ranks:
        # Empty dump dirs used to be a silent no-op; report them explicitly.
        logger.error(f'No rank folder is found in {npu_dump_dir} or {bench_dump_dir}. '
                     'Please check the dump paths and try again.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    for nr, br in zip(npu_ranks, bench_ranks):
        n_dir = os.path.join(npu_dump_dir, nr)
        b_dir = os.path.join(bench_dump_dir, br)
        # NOTE(review): the stack json is read from the benchmark rank dir; confirm it
        # should not come from the NPU rank dir instead.
        s_dir = b_dir
        npu_json_path = extract_json(n_dir, stack_json=False)
        bench_json_path = extract_json(b_dir, stack_json=False)
        stack_json_path = extract_json(s_dir, stack_json=True)

        dump_result_param = {
            'npu_json_path': npu_json_path,
            'bench_json_path': bench_json_path,
            'stack_json_path': stack_json_path,
            'is_print_compare_log': True
        }
        try:
            summary_compare, md5_compare = task_dumppath_get(dump_result_param)
            check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
            create_directory(output_path)
            check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare)
        except CompareException as error:
            logger.error('Compare failed. Please check the arguments and do it again!')
            sys.exit(error.code)
        compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
                     md5_compare=md5_compare, **kwargs)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import abc
|
|
3
|
+
import numpy as np
|
|
4
|
+
from msprobe.core.common.utils import get_header_index
|
|
5
|
+
from msprobe.core.common.const import CompareConst
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HighlightCheck(abc.ABC):
    """Interface shared by every highlight rule in this module.

    A rule inspects the compare-result data passed in ``info``. When the rule fires, it
    records a row number in one of the lists held by ``color_columns``; the
    implementations here use ``color_columns.red`` and ``color_columns.yellow``.
    """

    @abc.abstractmethod
    def apply(self, info, color_columns, summary_compare):
        """Evaluate the rule; concrete subclasses must override this method.

        Args:
            info: rule-specific tuple of compare-result rows plus the row number to flag.
            color_columns: collector whose colour lists receive flagged row numbers.
            summary_compare (bool): whether the rows come from a summary comparison.

        Raises:
            NotImplementedError: always, for the abstract base implementation.
        """
        raise NotImplementedError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CheckOrderMagnitude(HighlightCheck):
    """Check the difference in order of magnitude of the Max diff between an API's input and output."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the row yellow when the output's max diff is orders of magnitude above the input's.

        Args:
            info (tuple): ``(api_in, api_out, num)`` — an input result row, an output result
                row and the row number appended to ``color_columns.yellow`` when the rule fires.
            color_columns: collector with a ``yellow`` list.
            summary_compare (bool): selects the 'Max diff' column (True) or the
                'MaxAbsErr' column (False).
        """
        api_in, api_out, num = info
        max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
        try:
            in_abs = abs(api_in[max_diff_index])
            out_abs = abs(api_out[max_diff_index])
        except TypeError:
            # Non-numeric cells (e.g. placeholder strings) cannot be compared; skip the row,
            # as the other rules do, instead of aborting the whole highlight pass.
            return
        if in_abs > out_abs:
            return
        # Values below 1 count as order 0 so tiny diffs never yield negative orders.
        in_order = 0 if in_abs < 1 else math.log10(in_abs)
        out_order = 0 if out_abs < 1 else math.log10(out_abs)
        if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
            color_columns.yellow.append(num)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class CheckOneThousandErrorRatio(HighlightCheck):
    """Check the one-thousandth error ratio of an API's input against its output."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the row red or yellow based on the 'One Thousandth Err Ratio' column.

        Red: input ratio above ONE_THOUSAND_ERROR_IN_RED while the output ratio is
        below ONE_THOUSAND_ERROR_OUT_RED. Yellow (otherwise): the ratio dropped from
        input to output by more than ONE_THOUSAND_ERROR_DIFF_YELLOW. Rows whose cells
        are not int/float are ignored.

        Args:
            info (tuple): ``(api_in, api_out, num)`` — input row, output row, row number.
            color_columns: collector with ``red`` and ``yellow`` lists.
            summary_compare (bool): forwarded to ``get_header_index``.
        """
        input_row, output_row, row_num = info
        col = get_header_index('One Thousandth Err Ratio', summary_compare)
        in_ratio = input_row[col]
        out_ratio = output_row[col]
        numeric = (float, int)
        if isinstance(in_ratio, numeric) and isinstance(out_ratio, numeric):
            input_was_good = in_ratio > CompareConst.ONE_THOUSAND_ERROR_IN_RED
            output_is_bad = out_ratio < CompareConst.ONE_THOUSAND_ERROR_OUT_RED
            if input_was_good and output_is_bad:
                color_columns.red.append(row_num)
            elif in_ratio - out_ratio > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
                color_columns.yellow.append(row_num)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CheckCosineSimilarity(HighlightCheck):
    """Check whether cosine similarity drops noticeably from an API's input to its output."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the row yellow when input cosine minus output cosine exceeds COSINE_DIFF_YELLOW.

        Rows whose 'Cosine' cells are not int/float are ignored.

        Args:
            info (tuple): ``(api_in, api_out, num)`` — input row, output row, row number.
            color_columns: collector with a ``yellow`` list.
            summary_compare (bool): forwarded to ``get_header_index``.
        """
        input_row, output_row, row_num = info
        col = get_header_index('Cosine', summary_compare)
        in_cos, out_cos = input_row[col], output_row[col]
        both_numeric = isinstance(in_cos, (float, int)) and isinstance(out_cos, (float, int))
        if both_numeric and in_cos - out_cos > CompareConst.COSINE_DIFF_YELLOW:
            color_columns.yellow.append(row_num)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CheckMaxRelativeDiff(HighlightCheck):
    """Check the max relative difference (Max diff / Bench max) of an API's input and output."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the row based on the relative max diff of the input and output rows.

        Red: output relative diff > MAX_RELATIVE_OUT_RED. Yellow (otherwise): output
        relative diff > MAX_RELATIVE_OUT_YELLOW while the input's is < MAX_RELATIVE_IN_YELLOW.
        Rows whose cells cannot take part in the arithmetic are ignored.

        Args:
            info (tuple): ``(api_in, api_out, num)`` — input row, output row, row number.
            color_columns: collector with ``red`` and ``yellow`` lists.
            summary_compare (bool): forwarded to ``get_header_index``.
        """
        api_in, api_out, num = info
        max_diff_index = get_header_index('Max diff', summary_compare)
        bench_max_index = get_header_index('Bench max', summary_compare)
        try:
            # The denominator is clamped to at least 0.01 to avoid division by zero / tiny values.
            # NOTE(review): a negative Bench max is also clamped to 0.01, which inflates the
            # ratio — confirm whether abs(Bench max) was intended.
            input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
            output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
        except TypeError:
            # Non-numeric cells (e.g. placeholder strings) make max()/np.divide raise before
            # the type check below could run; skip the row instead of crashing.
            return
        if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
                                                                                     (float, int)):
            return
        if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
            color_columns.red.append(num)
        elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW
              and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
            color_columns.yellow.append(num)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class CheckOverflow(HighlightCheck):
    """Check a single result row for overflow markers or an excessive max diff."""

    def apply(self, info, color_columns, summary_compare=True):
        """Flag the row red on overflow.

        The row is red when the string form of its 'NPU max' or 'NPU min' cell is listed
        in CompareConst.OVERFLOW_LIST. It is also red when its max-diff cell ('Max diff'
        in summary mode, 'MaxAbsErr' otherwise) is numeric and above
        CompareConst.MAX_DIFF_RED. A row is appended at most once.

        Args:
            info (tuple): ``(line, num)`` — a result row and its row number.
            color_columns: collector with a ``red`` list.
            summary_compare (bool): selects the max-diff column and is forwarded to
                ``get_header_index``.
        """
        row, row_num = info
        max_col = get_header_index('NPU max', summary_compare)
        min_col = get_header_index('NPU min', summary_compare)
        diff_col = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)

        overflow_markers = CompareConst.OVERFLOW_LIST
        if any(str(row[col]) in overflow_markers for col in (max_col, min_col)):
            color_columns.red.append(row_num)
            return

        diff_value = row[diff_col]
        diff_too_large = isinstance(diff_value, (float, int)) and diff_value > CompareConst.MAX_DIFF_RED
        if diff_too_large:
            color_columns.red.append(row_num)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class HighlightRules:
    """Collection of highlight rules used to check API errors."""

    # Rules applied to each individual row.
    basic_rules = dict(
        check_overflow=CheckOverflow(),
    )

    # Rules that compare an API's input rows with its output rows.
    compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_one_thousand_error=CheckOneThousandErrorRatio(),
        check_cosine_similarity=CheckCosineSimilarity(),
    )

    # Input-vs-output rules for summary comparisons (presumably selected by the caller
    # when summary_compare is enabled — confirm against the caller).
    summary_compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_max_relative_diff=CheckMaxRelativeDiff(),
    )
|