mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
|
@@ -15,41 +15,48 @@
|
|
|
15
15
|
|
|
16
16
|
import multiprocessing
|
|
17
17
|
import os
|
|
18
|
+
import re
|
|
19
|
+
from copy import deepcopy
|
|
20
|
+
|
|
18
21
|
import pandas as pd
|
|
19
|
-
from
|
|
20
|
-
from msprobe.core.common.file_utils import load_json
|
|
22
|
+
from msprobe.core.advisor.advisor import Advisor
|
|
21
23
|
from msprobe.core.common.const import CompareConst, Const
|
|
22
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
23
|
-
from msprobe.core.common.
|
|
24
|
-
from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
|
|
25
|
+
from msprobe.core.common.file_utils import load_json
|
|
25
26
|
from msprobe.core.common.file_utils import remove_path
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid, safe_get_value
|
|
26
29
|
from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
|
|
27
|
-
|
|
30
|
+
check_stack_json_str
|
|
28
31
|
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
29
|
-
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
|
|
30
32
|
from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
|
|
31
33
|
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
|
|
32
34
|
get_error_message
|
|
33
|
-
from msprobe.core.
|
|
35
|
+
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy, \
|
|
36
|
+
get_rela_diff_summary_mode, print_compare_ends_info
|
|
37
|
+
from tqdm import tqdm
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
class Comparator:
|
|
37
|
-
|
|
41
|
+
|
|
38
42
|
def __init__(self):
|
|
39
43
|
pass
|
|
40
44
|
|
|
41
45
|
@staticmethod
|
|
42
46
|
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
47
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
48
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
49
|
+
|
|
50
|
+
if len(npu_struct) < 3 or len(bench_struct) < 3:
|
|
51
|
+
logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
|
|
52
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
|
|
53
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
54
|
+
|
|
55
|
+
result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
|
|
56
|
+
npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
|
|
57
|
+
CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
|
|
58
|
+
|
|
59
|
+
if len(args) >= 2 and args[0]:
|
|
53
60
|
result_item.extend(args[1])
|
|
54
61
|
else:
|
|
55
62
|
result_item.append(CompareConst.NONE)
|
|
@@ -58,59 +65,47 @@ class Comparator:
|
|
|
58
65
|
@staticmethod
|
|
59
66
|
def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
|
|
60
67
|
err_msg = ""
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
64
|
-
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
65
|
-
diff = npu_val - bench_val
|
|
66
|
-
if bench_val != 0:
|
|
67
|
-
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
68
|
-
else:
|
|
69
|
-
relative = "N/A"
|
|
70
|
-
result_item[start_idx + i] = diff
|
|
71
|
-
result_item[start_idx + i + 4] = relative
|
|
72
|
-
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
73
|
-
if magnitude_diff > 0.5:
|
|
74
|
-
warning_flag = True
|
|
75
|
-
else:
|
|
76
|
-
result_item[start_idx + i] = CompareConst.NONE
|
|
77
|
-
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
78
|
-
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
79
|
-
for i in range(start_idx, len(result_item)):
|
|
80
|
-
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
81
|
-
result_item[i] = f'{result_item[i]}\t'
|
|
68
|
+
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
69
|
+
bench_summary_data, err_msg)
|
|
82
70
|
result_item.append(accuracy_check)
|
|
83
71
|
result_item.append(err_msg)
|
|
84
|
-
|
|
72
|
+
|
|
73
|
+
@staticmethod
|
|
74
|
+
def _generate_na_data(ops_all):
|
|
75
|
+
if not ops_all:
|
|
76
|
+
return {}
|
|
77
|
+
key = next(iter(ops_all))
|
|
78
|
+
value = deepcopy(ops_all[key])
|
|
79
|
+
for k, v in value.items():
|
|
80
|
+
if isinstance(v, tuple):
|
|
81
|
+
value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
|
|
82
|
+
elif isinstance(v, list):
|
|
83
|
+
value[k] = [CompareConst.N_A] * len(v)
|
|
84
|
+
else:
|
|
85
|
+
value[k] = CompareConst.N_A
|
|
86
|
+
return value
|
|
87
|
+
|
|
85
88
|
@classmethod
|
|
86
|
-
def make_result_table(cls, result,
|
|
87
|
-
|
|
88
|
-
header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
|
|
89
|
-
elif summary_compare:
|
|
90
|
-
header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
|
|
91
|
-
else:
|
|
92
|
-
header = CompareConst.COMPARE_RESULT_HEADER[:]
|
|
89
|
+
def make_result_table(cls, result, stack_mode, dump_mode):
|
|
90
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
|
|
93
91
|
|
|
94
|
-
all_mode_bool = not (summary_compare or md5_compare)
|
|
95
92
|
if stack_mode:
|
|
96
|
-
|
|
97
|
-
|
|
93
|
+
header.append(CompareConst.STACK)
|
|
94
|
+
if dump_mode == Const.ALL:
|
|
98
95
|
header.append(CompareConst.DATA_NAME)
|
|
99
|
-
else:
|
|
100
|
-
header.append(CompareConst.STACK)
|
|
101
96
|
else:
|
|
102
|
-
if
|
|
97
|
+
if dump_mode == Const.ALL:
|
|
103
98
|
for row in result:
|
|
104
|
-
del row[-2]
|
|
99
|
+
del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
|
|
105
100
|
header.append(CompareConst.DATA_NAME)
|
|
106
101
|
else:
|
|
107
102
|
for row in result:
|
|
108
|
-
del row[-1]
|
|
103
|
+
del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
|
|
109
104
|
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
110
|
-
return result_df
|
|
111
|
-
|
|
105
|
+
return result_df
|
|
106
|
+
|
|
112
107
|
@classmethod
|
|
113
|
-
def gen_merge_list(cls, json_data, op_name, stack_json_data,
|
|
108
|
+
def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
|
|
114
109
|
op_data = json_data['data'][op_name]
|
|
115
110
|
check_dump_json_str(op_data, op_name)
|
|
116
111
|
op_parsed_list = read_op(op_data, op_name)
|
|
@@ -122,31 +117,32 @@ class Comparator:
|
|
|
122
117
|
'full_op_name': op_name,
|
|
123
118
|
'full_info': stack_info
|
|
124
119
|
})
|
|
125
|
-
|
|
126
|
-
merge_list = merge_tensor(op_parsed_list,
|
|
120
|
+
|
|
121
|
+
merge_list = merge_tensor(op_parsed_list, dump_mode)
|
|
127
122
|
return merge_list
|
|
128
|
-
|
|
123
|
+
|
|
129
124
|
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
graph_mode = check_graph_mode(
|
|
133
|
-
|
|
125
|
+
npu_op_name = npu_dict[CompareConst.OP_NAME]
|
|
126
|
+
bench_op_name = bench_dict[CompareConst.OP_NAME]
|
|
127
|
+
graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
|
|
128
|
+
safe_get_value(bench_op_name, 0, "bench_op_name"))
|
|
129
|
+
|
|
134
130
|
frame_name = getattr(self, "frame_name")
|
|
135
131
|
if frame_name == "PTComparator":
|
|
136
132
|
from msprobe.pytorch.compare.match import graph_mapping
|
|
137
133
|
if graph_mode:
|
|
138
|
-
return graph_mapping.match(
|
|
134
|
+
return graph_mapping.match(npu_op_name[0], bench_op_name[0])
|
|
139
135
|
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
140
136
|
if not fuzzy_match:
|
|
141
|
-
return
|
|
137
|
+
return npu_op_name == bench_op_name and struct_match
|
|
142
138
|
is_match = True
|
|
143
139
|
try:
|
|
144
|
-
is_match = fuzzy_check_op(
|
|
140
|
+
is_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
145
141
|
except Exception as err:
|
|
146
|
-
logger.warning("%s and %s can not fuzzy match." % (
|
|
142
|
+
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
147
143
|
is_match = False
|
|
148
144
|
return is_match and struct_match
|
|
149
|
-
|
|
145
|
+
|
|
150
146
|
def match_op(self, npu_queue, bench_queue, fuzzy_match):
|
|
151
147
|
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
152
148
|
if self.check_op(npu_queue[-1], b_op, fuzzy_match):
|
|
@@ -157,8 +153,8 @@ class Comparator:
|
|
|
157
153
|
if self.check_op(n_op, bench_queue[-1], fuzzy_match):
|
|
158
154
|
return n_index, len(bench_queue) - 1
|
|
159
155
|
return -1, -1
|
|
160
|
-
|
|
161
|
-
def compare_process(self, file_lists, stack_mode, fuzzy_match,
|
|
156
|
+
|
|
157
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
|
|
162
158
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
163
159
|
npu_json_data = load_json(npu_json_path)
|
|
164
160
|
bench_json_data = load_json(bench_json_path)
|
|
@@ -189,8 +185,7 @@ class Comparator:
|
|
|
189
185
|
op_name_npu = next(ops_npu_iter)
|
|
190
186
|
check_op_str_pattern_valid(op_name_npu)
|
|
191
187
|
read_err_npu = True
|
|
192
|
-
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
|
|
193
|
-
summary_compare, md5_compare)
|
|
188
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
|
|
194
189
|
if npu_merge_list:
|
|
195
190
|
npu_ops_queue.append(npu_merge_list)
|
|
196
191
|
except StopIteration:
|
|
@@ -199,8 +194,7 @@ class Comparator:
|
|
|
199
194
|
last_bench_ops_len = len(bench_ops_queue)
|
|
200
195
|
op_name_bench = next(ops_bench_iter)
|
|
201
196
|
check_op_str_pattern_valid(op_name_bench)
|
|
202
|
-
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
|
|
203
|
-
summary_compare, md5_compare)
|
|
197
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data, dump_mode)
|
|
204
198
|
if bench_merge_list:
|
|
205
199
|
bench_ops_queue.append(bench_merge_list)
|
|
206
200
|
except StopIteration:
|
|
@@ -226,71 +220,93 @@ class Comparator:
|
|
|
226
220
|
b_match_data = bench_ops_queue[b_match_point]
|
|
227
221
|
un_match_data = npu_ops_queue[0: n_match_point]
|
|
228
222
|
for npu_data in un_match_data:
|
|
229
|
-
get_un_match_accuracy(result, npu_data,
|
|
230
|
-
get_accuracy(result, n_match_data, b_match_data,
|
|
223
|
+
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
224
|
+
get_accuracy(result, n_match_data, b_match_data, dump_mode)
|
|
231
225
|
del npu_ops_queue[0: n_match_point + 1]
|
|
232
226
|
del bench_ops_queue[0: b_match_point + 1]
|
|
227
|
+
progress_bar.close()
|
|
233
228
|
if npu_ops_queue:
|
|
234
229
|
for npu_data in npu_ops_queue:
|
|
235
|
-
get_un_match_accuracy(result, npu_data,
|
|
236
|
-
|
|
237
|
-
result_df = self.make_result_table(result,
|
|
230
|
+
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
231
|
+
|
|
232
|
+
result_df = self.make_result_table(result, stack_mode, dump_mode)
|
|
238
233
|
return result_df
|
|
239
234
|
|
|
240
|
-
def merge_data(self, json_data, stack_json_data,
|
|
235
|
+
def merge_data(self, json_data, stack_json_data, dump_mode):
|
|
241
236
|
ops_all = {}
|
|
242
237
|
for op_name in json_data.get('data', {}):
|
|
243
|
-
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data,
|
|
244
|
-
md5_compare)
|
|
238
|
+
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, dump_mode)
|
|
245
239
|
if merge_list:
|
|
246
240
|
input_index, output_index = 0, 0
|
|
247
|
-
for index, input_or_output in enumerate(merge_list[
|
|
241
|
+
for index, input_or_output in enumerate(merge_list[CompareConst.OP_NAME]):
|
|
248
242
|
input_or_output_list = input_or_output.split(Const.SEP)
|
|
249
243
|
data_name = merge_list.get('data_name')
|
|
250
244
|
data_name = data_name[index] if data_name else None
|
|
251
245
|
if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
|
|
252
|
-
ops_all[input_or_output] = {
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
246
|
+
ops_all[input_or_output] = {
|
|
247
|
+
CompareConst.STRUCT: safe_get_value(merge_list, input_index, "merge_list",
|
|
248
|
+
key=CompareConst.INPUT_STRUCT),
|
|
249
|
+
CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
|
|
250
|
+
key=CompareConst.SUMMARY),
|
|
251
|
+
'data_name': data_name,
|
|
252
|
+
'stack_info': merge_list.get('stack_info')
|
|
253
|
+
}
|
|
256
254
|
input_index += 1
|
|
257
255
|
|
|
258
256
|
elif Const.OUTPUT in input_or_output_list:
|
|
259
|
-
ops_all[input_or_output] = {
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
257
|
+
ops_all[input_or_output] = {
|
|
258
|
+
CompareConst.STRUCT: safe_get_value(merge_list, output_index, "merge_list",
|
|
259
|
+
key=CompareConst.OUTPUT_STRUCT),
|
|
260
|
+
CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
|
|
261
|
+
key=CompareConst.SUMMARY),
|
|
262
|
+
'data_name': data_name,
|
|
263
|
+
'stack_info': merge_list.get('stack_info')
|
|
264
|
+
}
|
|
263
265
|
output_index += 1
|
|
264
266
|
return ops_all
|
|
265
267
|
|
|
266
|
-
def get_accuracy(self, npu_ops_all, bench_ops_all,
|
|
268
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all, dump_mode):
|
|
267
269
|
result = []
|
|
270
|
+
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
268
271
|
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
269
272
|
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
270
273
|
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
271
274
|
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
272
275
|
has_stack = npu_stack_info and bench_stack_info
|
|
273
|
-
if
|
|
276
|
+
if dump_mode == Const.MD5:
|
|
274
277
|
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
275
278
|
bench_ops_all, has_stack, npu_stack_info))
|
|
276
279
|
continue
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
280
|
+
|
|
281
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
282
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
283
|
+
|
|
284
|
+
if len(npu_struct) < 2 or len(bench_struct) < 2:
|
|
285
|
+
logger.error(
|
|
286
|
+
f"The length of npu_struct and bench_struct must be >= 2, "
|
|
287
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
|
|
288
|
+
f"Please check!"
|
|
289
|
+
)
|
|
290
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
291
|
+
|
|
292
|
+
base_result_item = [
|
|
293
|
+
ms_op_name, bench_op_name,
|
|
294
|
+
npu_struct[0],
|
|
295
|
+
bench_struct[0],
|
|
296
|
+
npu_struct[1],
|
|
297
|
+
bench_struct[1]
|
|
298
|
+
]
|
|
299
|
+
|
|
300
|
+
if dump_mode == Const.SUMMARY:
|
|
301
|
+
result_item = base_result_item + [" "] * 8
|
|
283
302
|
else:
|
|
284
|
-
result_item =
|
|
285
|
-
|
|
286
|
-
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
287
|
-
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
288
|
-
" ", " ", " ", " ", " "]
|
|
303
|
+
result_item = base_result_item + [" "] * 5
|
|
304
|
+
|
|
289
305
|
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
290
306
|
result_item.extend(npu_summary_data)
|
|
291
307
|
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
292
308
|
result_item.extend(bench_summary_data)
|
|
293
|
-
if
|
|
309
|
+
if dump_mode == Const.SUMMARY:
|
|
294
310
|
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
295
311
|
else:
|
|
296
312
|
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
@@ -299,7 +315,7 @@ class Comparator:
|
|
|
299
315
|
result_item.extend(npu_stack_info)
|
|
300
316
|
else:
|
|
301
317
|
result_item.append(CompareConst.NONE)
|
|
302
|
-
if
|
|
318
|
+
if dump_mode == Const.ALL:
|
|
303
319
|
result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
|
|
304
320
|
result.append(result_item)
|
|
305
321
|
elif ms_op_name not in npu_ops_all:
|
|
@@ -308,26 +324,40 @@ class Comparator:
|
|
|
308
324
|
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
309
325
|
return result
|
|
310
326
|
|
|
311
|
-
def compare_process_custom(self, file_lists, stack_mode,
|
|
327
|
+
def compare_process_custom(self, file_lists, stack_mode, dump_mode):
|
|
312
328
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
313
329
|
npu_json_data = load_json(npu_json_path)
|
|
314
330
|
bench_json_data = load_json(bench_json_path)
|
|
315
331
|
stack_json_data = load_json(stack_json_path)
|
|
316
332
|
|
|
317
|
-
npu_ops_all = self.merge_data(npu_json_data, stack_json_data,
|
|
318
|
-
bench_ops_all = self.merge_data(bench_json_data, stack_json_data,
|
|
333
|
+
npu_ops_all = self.merge_data(npu_json_data, stack_json_data, dump_mode)
|
|
334
|
+
bench_ops_all = self.merge_data(bench_json_data, stack_json_data, dump_mode)
|
|
319
335
|
|
|
320
|
-
result = self.get_accuracy(npu_ops_all, bench_ops_all,
|
|
321
|
-
result_df = self.make_result_table(result,
|
|
336
|
+
result = self.get_accuracy(npu_ops_all, bench_ops_all, dump_mode)
|
|
337
|
+
result_df = self.make_result_table(result, stack_mode, dump_mode)
|
|
322
338
|
return result_df
|
|
323
339
|
|
|
324
|
-
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
340
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
|
|
341
|
+
"""
|
|
342
|
+
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
343
|
+
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
344
|
+
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
345
|
+
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
346
|
+
:param bench_data: bench的dump数据中"data"字段
|
|
347
|
+
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
348
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
|
|
349
|
+
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
350
|
+
"""
|
|
325
351
|
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
326
|
-
data_name = npu_bench_name_list
|
|
352
|
+
data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
|
|
327
353
|
error_file, relative_err, error_flag = None, None, False
|
|
354
|
+
bench_data_name = get_bench_data_name(bench_op_name, bench_data)
|
|
328
355
|
if data_name == '-1' or data_name == -1: # 没有真实数据路径
|
|
329
356
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
330
357
|
error_flag = True
|
|
358
|
+
elif not bench_data_name:
|
|
359
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
360
|
+
error_file = 'no_bench_data'
|
|
331
361
|
else:
|
|
332
362
|
try:
|
|
333
363
|
read_npy_data = getattr(self, "read_npy_data")
|
|
@@ -335,19 +365,18 @@ class Comparator:
|
|
|
335
365
|
if frame_name == "MSComparator":
|
|
336
366
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
|
|
337
367
|
if self.cross_frame:
|
|
338
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
339
|
-
|
|
368
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
|
|
369
|
+
load_pt_file=True)
|
|
340
370
|
else:
|
|
341
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
342
|
-
bench_op_name + Const.NUMPY_SUFFIX)
|
|
371
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
343
372
|
else:
|
|
344
373
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
|
|
345
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
374
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
346
375
|
except IOError as error:
|
|
347
376
|
error_file = error.filename
|
|
348
377
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
349
378
|
error_flag = True
|
|
350
|
-
except FileCheckException:
|
|
379
|
+
except (FileCheckException, CompareException):
|
|
351
380
|
error_file = data_name
|
|
352
381
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
353
382
|
error_flag = True
|
|
@@ -364,7 +393,7 @@ class Comparator:
|
|
|
364
393
|
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
365
394
|
result_list.append(err_msg)
|
|
366
395
|
return result_list
|
|
367
|
-
|
|
396
|
+
|
|
368
397
|
def compare_core(self, input_parma, output_path, **kwargs):
|
|
369
398
|
"""
|
|
370
399
|
Compares data from multiple JSON files and generates a comparison report.
|
|
@@ -378,8 +407,7 @@ class Comparator:
|
|
|
378
407
|
- auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
|
|
379
408
|
- suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
|
|
380
409
|
- fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
|
|
381
|
-
-
|
|
382
|
-
- md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
|
|
410
|
+
- dump_mode (str): ALL, SUMMARY, MD5.
|
|
383
411
|
|
|
384
412
|
Returns:
|
|
385
413
|
"""
|
|
@@ -388,41 +416,43 @@ class Comparator:
|
|
|
388
416
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
389
417
|
suffix = kwargs.get('suffix', '')
|
|
390
418
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
391
|
-
|
|
392
|
-
md5_compare = kwargs.get('md5_compare', False)
|
|
419
|
+
dump_mode = kwargs.get('dump_mode', None)
|
|
393
420
|
|
|
394
421
|
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
395
422
|
file_name = add_time_with_xlsx("compare_result" + suffix)
|
|
396
423
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
397
424
|
remove_path(file_path)
|
|
398
|
-
highlight_dict = {
|
|
425
|
+
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
399
426
|
|
|
400
427
|
npu_json = input_parma.get("npu_json_path")
|
|
401
428
|
bench_json = input_parma.get("bench_json_path")
|
|
402
429
|
stack_json = input_parma.get("stack_json_path")
|
|
403
430
|
if self.data_mapping:
|
|
404
|
-
result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
|
|
405
|
-
summary_compare, md5_compare)
|
|
431
|
+
result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode, dump_mode)
|
|
406
432
|
else:
|
|
407
|
-
result_df = self.compare_process(
|
|
408
|
-
|
|
433
|
+
result_df = self.compare_process(
|
|
434
|
+
[npu_json, bench_json, stack_json],
|
|
435
|
+
stack_mode,
|
|
436
|
+
fuzzy_match,
|
|
437
|
+
dump_mode
|
|
438
|
+
)
|
|
409
439
|
|
|
410
440
|
if not result_df.values.tolist():
|
|
411
441
|
logger.warning("Can`t match any op.")
|
|
412
442
|
return
|
|
413
443
|
|
|
414
|
-
if
|
|
415
|
-
result_df = self.
|
|
444
|
+
if dump_mode == Const.ALL:
|
|
445
|
+
result_df = self.do_multi_process(input_parma, result_df)
|
|
416
446
|
|
|
417
|
-
|
|
418
|
-
find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
|
|
447
|
+
find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
|
|
419
448
|
highlight_rows_xlsx(result_df, highlight_dict, file_path)
|
|
420
|
-
logger.info("Highlight suspicious API/Module finish.")
|
|
421
449
|
|
|
422
450
|
if auto_analyze:
|
|
423
451
|
advisor = Advisor(result_df, output_path, suffix)
|
|
424
452
|
advisor.analysis()
|
|
425
|
-
|
|
453
|
+
|
|
454
|
+
print_compare_ends_info()
|
|
455
|
+
|
|
426
456
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
427
457
|
cos_result = []
|
|
428
458
|
max_err_result = []
|
|
@@ -431,13 +461,16 @@ class Comparator:
|
|
|
431
461
|
one_thousand_err_ratio_result = []
|
|
432
462
|
five_thousand_err_ratio_result = []
|
|
433
463
|
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
464
|
+
bench_data = load_json(input_param.get("bench_json_path")).get('data')
|
|
434
465
|
for i in range(len(result_df)):
|
|
435
466
|
npu_op_name = result_df.iloc[i, 0]
|
|
436
467
|
bench_op_name = result_df.iloc[i, 1]
|
|
437
468
|
if is_print_compare_log:
|
|
438
469
|
logger.info("start compare: {}".format(npu_op_name))
|
|
470
|
+
|
|
439
471
|
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
|
|
440
|
-
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
472
|
+
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
|
|
473
|
+
|
|
441
474
|
if is_print_compare_log:
|
|
442
475
|
logger.info(
|
|
443
476
|
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
|
|
@@ -460,9 +493,9 @@ class Comparator:
|
|
|
460
493
|
five_thousand_err_ratio_result=five_thousand_err_ratio_result
|
|
461
494
|
)
|
|
462
495
|
|
|
463
|
-
return _save_cmp_result(idx, cr, result_df, lock)
|
|
464
|
-
|
|
465
|
-
def
|
|
496
|
+
return _save_cmp_result(idx, cr, result_df, lock)
|
|
497
|
+
|
|
498
|
+
def do_multi_process(self, input_parma, result_df):
|
|
466
499
|
try:
|
|
467
500
|
result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
|
|
468
501
|
multiprocessing.Manager().RLock())
|
|
@@ -470,4 +503,36 @@ class Comparator:
|
|
|
470
503
|
except ValueError as e:
|
|
471
504
|
logger.error('result dataframe is not found.')
|
|
472
505
|
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
473
|
-
|
|
506
|
+
|
|
507
|
+
def get_bench_data_name(bench_op_name, bench_data):
    """
    Resolve the data-name entry recorded in the bench dump for one bench op.

    :param bench_op_name: Bench_Name taken from the result table,
        e.g. Functional.conv2d.0.forward.input.3.0
    :param bench_data: the "data" field of the bench dump json; may be None
        when the dump has no such field
    :return: the value stored under ``CompareConst.DATA_NAME.lower()`` for the
        addressed tensor, or None when the name cannot be parsed or any step
        of the lookup is missing
    """
    # The caller obtains bench_data via ``load_json(...).get('data')``, which
    # yields None for a dump without a "data" field; treat that (and a
    # non-string op name) as "nothing to look up" instead of raising.
    if not isinstance(bench_op_name, str) or not isinstance(bench_data, dict):
        return None

    # The capturing group keeps the matched keyword in the split result:
    # [api_name, "input" | "output" | "kwargs", "index.path", ...]
    bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
    bench_data_bundle = bench_data.get(bench_name_list[0], {})
    if (not isinstance(bench_data_bundle, dict) or not bench_data_bundle
            or len(bench_name_list) < 3):
        return None
    layers = bench_name_list[2].split(Const.SEP)

    def get(key, container):
        # One lookup step: dicts are addressed by key, lists by integer index.
        if isinstance(container, dict):
            return container.get(key)
        if isinstance(container, list):
            try:
                return container[int(key)]
            except (ValueError, IndexError):
                return None
        # None or a scalar: the path cannot be followed any further.
        return None

    def get_by_layer(container):
        # Walk the dotted index path, then read the data-name entry of the leaf.
        data = container
        for layer in layers:
            data = get(layer, data)
        return get(CompareConst.DATA_NAME.lower(), data)

    if Const.INPUT == bench_name_list[1]:
        # Positional inputs are stored under Const.INPUT or, failing that, Const.INPUT_ARGS.
        return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
    elif Const.KWARGS == bench_name_list[1]:
        return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
    elif Const.OUTPUT == bench_name_list[1]:
        return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
    else:
        return None
|
|
538
|
+
|
msprobe/core/compare/check.py
CHANGED
|
@@ -35,18 +35,15 @@ dtype_mapping = {
|
|
|
35
35
|
"BFloat16": "torch.bfloat16",
|
|
36
36
|
"Complex64": "torch.complex64",
|
|
37
37
|
"Complex128": "torch.complex128"
|
|
38
|
-
|
|
38
|
+
}
|
|
39
39
|
|
|
40
40
|
|
|
41
|
-
def check_struct_match(npu_dict, bench_dict
|
|
41
|
+
def check_struct_match(npu_dict, bench_dict):
|
|
42
42
|
npu_struct_in = npu_dict.get("input_struct")
|
|
43
43
|
bench_struct_in = bench_dict.get("input_struct")
|
|
44
44
|
npu_struct_out = npu_dict.get("output_struct")
|
|
45
45
|
bench_struct_out = bench_dict.get("output_struct")
|
|
46
46
|
|
|
47
|
-
if cross_frame:
|
|
48
|
-
npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
|
|
49
|
-
npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
|
|
50
47
|
is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
|
|
51
48
|
if not is_match:
|
|
52
49
|
if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
|
|
@@ -14,17 +14,22 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
from msprobe.core.common.file_utils import
|
|
17
|
+
from msprobe.core.common.file_utils import check_file_type, load_json
|
|
18
18
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
19
19
|
from msprobe.core.common.utils import CompareException
|
|
20
20
|
from msprobe.core.common.log import logger
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
def compare_cli(args):
|
|
24
|
-
|
|
25
|
-
input_param = json.load(file)
|
|
24
|
+
input_param = load_json(args.input_path)
|
|
26
25
|
npu_path = input_param.get("npu_path", None)
|
|
27
26
|
bench_path = input_param.get("bench_path", None)
|
|
27
|
+
if not npu_path:
|
|
28
|
+
logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
|
|
29
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
30
|
+
if not bench_path:
|
|
31
|
+
logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
|
|
32
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
28
33
|
frame_name = args.framework
|
|
29
34
|
auto_analyze = not args.compare_only
|
|
30
35
|
if frame_name == Const.PT_FRAMEWORK:
|
|
@@ -34,6 +39,9 @@ def compare_cli(args):
|
|
|
34
39
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
35
40
|
from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
|
|
36
41
|
if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
|
|
42
|
+
if "stack_path" not in input_param:
|
|
43
|
+
logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
|
|
44
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
37
45
|
input_param["npu_json_path"] = input_param.pop("npu_path")
|
|
38
46
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
39
47
|
input_param["stack_json_path"] = input_param.pop("stack_path")
|
|
@@ -56,7 +64,16 @@ def compare_cli(args):
|
|
|
56
64
|
|
|
57
65
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
58
66
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
59
|
-
kwargs = {
|
|
67
|
+
kwargs = {
|
|
68
|
+
"stack_mode": args.stack_mode,
|
|
69
|
+
"auto_analyze": auto_analyze,
|
|
70
|
+
"fuzzy_match": args.fuzzy_match,
|
|
71
|
+
"is_print_compare_log": input_param.get("is_print_compare_log", True),
|
|
72
|
+
"cell_mapping": args.cell_mapping,
|
|
73
|
+
"api_mapping": args.api_mapping,
|
|
74
|
+
"data_mapping": args.data_mapping,
|
|
75
|
+
"layer_mapping": args.layer_mapping
|
|
76
|
+
}
|
|
60
77
|
if input_param.get("rank_id") is not None:
|
|
61
78
|
ms_graph_compare(input_param, args.output_path)
|
|
62
79
|
return
|