mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
- mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
- msprobe/README.md +78 -23
- msprobe/__init__.py +1 -0
- msprobe/config/README.md +182 -40
- msprobe/config/config.json +22 -0
- msprobe/core/__init__.py +0 -0
- msprobe/{pytorch → core}/advisor/advisor.py +3 -3
- msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
- msprobe/core/common/const.py +82 -5
- msprobe/core/common/exceptions.py +30 -18
- msprobe/core/common/file_check.py +19 -1
- msprobe/core/common/log.py +15 -1
- msprobe/core/common/utils.py +130 -30
- msprobe/core/common_config.py +32 -19
- msprobe/core/compare/acc_compare.py +299 -0
- msprobe/core/compare/check.py +95 -0
- msprobe/core/compare/compare_cli.py +49 -0
- msprobe/core/compare/highlight.py +222 -0
- msprobe/core/compare/multiprocessing_compute.py +149 -0
- msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
- msprobe/core/compare/utils.py +429 -0
- msprobe/core/data_dump/data_collector.py +39 -35
- msprobe/core/data_dump/data_processor/base.py +85 -37
- msprobe/core/data_dump/data_processor/factory.py +5 -7
- msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
- msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
- msprobe/core/data_dump/json_writer.py +11 -11
- msprobe/core/grad_probe/__init__.py +0 -0
- msprobe/core/grad_probe/constant.py +71 -0
- msprobe/core/grad_probe/grad_compare.py +175 -0
- msprobe/core/grad_probe/utils.py +52 -0
- msprobe/doc/grad_probe/grad_probe.md +207 -0
- msprobe/doc/grad_probe/img/image-1.png +0 -0
- msprobe/doc/grad_probe/img/image-2.png +0 -0
- msprobe/doc/grad_probe/img/image-3.png +0 -0
- msprobe/doc/grad_probe/img/image-4.png +0 -0
- msprobe/doc/grad_probe/img/image.png +0 -0
- msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
- msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
- msprobe/mindspore/api_accuracy_checker/main.py +16 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
- msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
- msprobe/mindspore/cell_processor.py +34 -0
- msprobe/mindspore/common/const.py +87 -0
- msprobe/mindspore/common/log.py +38 -0
- msprobe/mindspore/common/utils.py +57 -0
- msprobe/mindspore/compare/distributed_compare.py +75 -0
- msprobe/mindspore/compare/ms_compare.py +117 -0
- msprobe/mindspore/compare/ms_graph_compare.py +317 -0
- msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
- msprobe/mindspore/debugger/debugger_config.py +38 -15
- msprobe/mindspore/debugger/precision_debugger.py +79 -4
- msprobe/mindspore/doc/compare.md +58 -0
- msprobe/mindspore/doc/dump.md +158 -6
- msprobe/mindspore/dump/dump_tool_factory.py +19 -22
- msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
- msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
- msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
- msprobe/mindspore/dump/jit_dump.py +56 -0
- msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
- msprobe/mindspore/free_benchmark/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
- msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/common/config.py +12 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
- msprobe/mindspore/free_benchmark/common/utils.py +71 -0
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
- msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
- msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
- msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
- msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
- msprobe/mindspore/grad_probe/__init__.py +0 -0
- msprobe/mindspore/grad_probe/global_context.py +91 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
- msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
- msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
- msprobe/mindspore/grad_probe/hook.py +92 -0
- msprobe/mindspore/grad_probe/utils.py +29 -0
- msprobe/mindspore/ms_config.py +63 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
- msprobe/mindspore/runtime.py +4 -0
- msprobe/mindspore/service.py +354 -0
- msprobe/mindspore/task_handler_factory.py +7 -4
- msprobe/msprobe.py +66 -26
- msprobe/pytorch/__init__.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
- msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
- msprobe/pytorch/bench_functions/__init__.py +15 -0
- msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
- msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
- msprobe/pytorch/bench_functions/linear.py +12 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
- msprobe/pytorch/bench_functions/swiglu.py +55 -0
- msprobe/pytorch/common/parse_json.py +3 -1
- msprobe/pytorch/common/utils.py +83 -7
- msprobe/pytorch/compare/distributed_compare.py +19 -64
- msprobe/pytorch/compare/match.py +3 -6
- msprobe/pytorch/compare/pt_compare.py +40 -0
- msprobe/pytorch/debugger/debugger_config.py +11 -2
- msprobe/pytorch/debugger/precision_debugger.py +34 -4
- msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
- msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
- msprobe/pytorch/doc/dump.md +73 -20
- msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
- msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
- msprobe/pytorch/doc/run_overflow_check.md +1 -1
- msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +151 -0
- msprobe/pytorch/free_benchmark/common/constant.py +3 -0
- msprobe/pytorch/free_benchmark/common/utils.py +4 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
- msprobe/pytorch/function_factory.py +75 -0
- msprobe/pytorch/functional/dump_module.py +4 -4
- msprobe/pytorch/grad_probe/__init__.py +0 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
- msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
- msprobe/pytorch/hook_module/hook_module.py +14 -3
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/utils.py +9 -9
- msprobe/pytorch/hook_module/wrap_aten.py +20 -10
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
- msprobe/pytorch/hook_module/wrap_functional.py +4 -7
- msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
- msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
- msprobe/pytorch/hook_module/wrap_torch.py +5 -7
- msprobe/pytorch/hook_module/wrap_vf.py +6 -8
- msprobe/pytorch/module_processer.py +53 -13
- msprobe/pytorch/online_dispatch/compare.py +4 -4
- msprobe/pytorch/online_dispatch/dispatch.py +39 -41
- msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
- msprobe/pytorch/online_dispatch/single_compare.py +5 -5
- msprobe/pytorch/online_dispatch/utils.py +2 -43
- msprobe/pytorch/parse_tool/lib/compare.py +31 -19
- msprobe/pytorch/parse_tool/lib/config.py +2 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
- msprobe/pytorch/parse_tool/lib/utils.py +34 -80
- msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
- msprobe/pytorch/pt_config.py +100 -6
- msprobe/pytorch/service.py +104 -19
- mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
- msprobe/mindspore/dump/api_kbk_dump.py +0 -55
- msprobe/pytorch/compare/acc_compare.py +0 -1024
- msprobe/pytorch/compare/highlight.py +0 -100
- msprobe/test/core_ut/common/test_utils.py +0 -345
- msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
- msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
- msprobe/test/core_ut/data_dump/test_scope.py +0 -151
- msprobe/test/core_ut/test_common_config.py +0 -152
- msprobe/test/core_ut/test_file_check.py +0 -218
- msprobe/test/core_ut/test_log.py +0 -109
- msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
- msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
- msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
- msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
- msprobe/test/mindspore_ut/test_ms_config.py +0 -69
- msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
- msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
- msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
- msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
- msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
- msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
- msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
- msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
- msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
- msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
- msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
- msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
- msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
- msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
- msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
- msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
- msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
- msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
- msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
- msprobe/test/pytorch_ut/test_pt_config.py +0 -69
- msprobe/test/pytorch_ut/test_service.py +0 -59
- msprobe/test/resources/advisor.txt +0 -3
- msprobe/test/resources/compare_result_20230703104808.csv +0 -9
- msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
- msprobe/test/resources/config.yaml +0 -3
- msprobe/test/resources/npu_test.pkl +0 -8
- msprobe/test/run_test.sh +0 -30
- msprobe/test/run_ut.py +0 -58
- msprobe/test/test_module_processer.py +0 -64
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
- /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
|
@@ -1,1024 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# -*- coding: utf-8 -*-
|
|
3
|
-
"""
|
|
4
|
-
# Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
|
|
5
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
-
# you may not use this file except in compliance with the License.
|
|
7
|
-
# You may obtain a copy of the License at
|
|
8
|
-
#
|
|
9
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
-
#
|
|
11
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
-
# See the License for the specific language governing permissions and
|
|
15
|
-
# limitations under the License.
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
import json
|
|
19
|
-
import multiprocessing
|
|
20
|
-
import os.path
|
|
21
|
-
import sys
|
|
22
|
-
import torch
|
|
23
|
-
import numpy as np
|
|
24
|
-
import pandas as pd
|
|
25
|
-
import openpyxl
|
|
26
|
-
from openpyxl.styles import PatternFill
|
|
27
|
-
from collections import namedtuple
|
|
28
|
-
from dataclasses import dataclass
|
|
29
|
-
|
|
30
|
-
from msprobe.pytorch.compare.match import graph_mapping
|
|
31
|
-
from msprobe.pytorch.compare.highlight import HighlightRules, get_header_index
|
|
32
|
-
from msprobe.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
|
|
33
|
-
get_error_message
|
|
34
|
-
from msprobe.pytorch.advisor.advisor import Advisor
|
|
35
|
-
from msprobe.pytorch.common.log import logger
|
|
36
|
-
from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \
|
|
37
|
-
format_value, check_file_not_exists, check_configuration_param, task_dumppath_get
|
|
38
|
-
from msprobe.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory
|
|
39
|
-
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def check_graph_mode(a_op_name, b_op_name):
|
|
43
|
-
if "Aten" in a_op_name and "Aten" not in b_op_name:
|
|
44
|
-
return True
|
|
45
|
-
if "Aten" not in a_op_name and "Aten" in b_op_name:
|
|
46
|
-
return True
|
|
47
|
-
return False
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def check_op(npu_dict, bench_dict, fuzzy_match):
|
|
51
|
-
a_op_name = npu_dict["op_name"]
|
|
52
|
-
b_op_name = bench_dict["op_name"]
|
|
53
|
-
graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
|
|
54
|
-
if graph_mode:
|
|
55
|
-
return graph_mapping.match(a_op_name[0], b_op_name[0])
|
|
56
|
-
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
57
|
-
if not fuzzy_match:
|
|
58
|
-
return a_op_name == b_op_name and struct_match
|
|
59
|
-
is_match = True
|
|
60
|
-
try:
|
|
61
|
-
is_match = fuzzy_check_op(a_op_name, b_op_name)
|
|
62
|
-
except Exception as err:
|
|
63
|
-
logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name))
|
|
64
|
-
is_match = False
|
|
65
|
-
return is_match and struct_match
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def check_struct_match(npu_dict, bench_dict):
|
|
69
|
-
npu_struct_in = npu_dict.get("input_struct")
|
|
70
|
-
bench_struct_in = bench_dict.get("input_struct")
|
|
71
|
-
npu_struct_out = npu_dict.get("output_struct")
|
|
72
|
-
bench_struct_out = bench_dict.get("output_struct")
|
|
73
|
-
is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
|
|
74
|
-
if not is_match:
|
|
75
|
-
if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
|
|
76
|
-
return False
|
|
77
|
-
struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
|
|
78
|
-
struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
|
|
79
|
-
is_match = struct_in_is_match and struct_out_is_match
|
|
80
|
-
return is_match
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
def check_type_shape_match(npu_struct, bench_struct):
|
|
84
|
-
shape_type_match = False
|
|
85
|
-
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
86
|
-
npu_type = npu_type_shape[0]
|
|
87
|
-
npu_shape = npu_type_shape[1]
|
|
88
|
-
bench_type = bench_type_shape[0]
|
|
89
|
-
bench_shape = bench_type_shape[1]
|
|
90
|
-
shape_match = npu_shape == bench_shape
|
|
91
|
-
type_match = npu_type == bench_type
|
|
92
|
-
if not type_match:
|
|
93
|
-
if [npu_type, bench_type] in [["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"],
|
|
94
|
-
["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]]:
|
|
95
|
-
type_match = True
|
|
96
|
-
else:
|
|
97
|
-
type_match = False
|
|
98
|
-
shape_type_match = shape_match and type_match
|
|
99
|
-
if not shape_type_match:
|
|
100
|
-
return False
|
|
101
|
-
return shape_type_match
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
105
|
-
if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
|
|
106
|
-
return False
|
|
107
|
-
is_match = True
|
|
108
|
-
for npu_name, bench_name in zip(npu_name_list, bench_name_list):
|
|
109
|
-
is_match = fuzzy_check_name(npu_name, bench_name)
|
|
110
|
-
if not is_match:
|
|
111
|
-
break
|
|
112
|
-
return is_match
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
def fuzzy_check_name(npu_name, bench_name):
|
|
116
|
-
if "forward" in npu_name and "forward" in bench_name:
|
|
117
|
-
is_match = rename_api(npu_name, "forward") == rename_api(bench_name, "forward")
|
|
118
|
-
elif "backward" in npu_name and "backward" in bench_name:
|
|
119
|
-
is_match = rename_api(npu_name, "backward") == rename_api(bench_name, "backward")
|
|
120
|
-
else:
|
|
121
|
-
is_match = npu_name == bench_name
|
|
122
|
-
return is_match
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def rename_api(npu_name, process):
|
|
126
|
-
npu_split = npu_name.split(process)
|
|
127
|
-
torch_func_index, in_out = npu_split[0], npu_split[1]
|
|
128
|
-
torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
|
|
129
|
-
torch_func = str(torch_func_split[0]) + str(in_out)
|
|
130
|
-
return torch_func
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def merge_tensor(tensor_list, summary_compare, md5_compare):
|
|
134
|
-
op_dict = {}
|
|
135
|
-
op_dict["op_name"] = []
|
|
136
|
-
op_dict["input_struct"] = []
|
|
137
|
-
op_dict["kwargs_struct"] = []
|
|
138
|
-
op_dict["output_struct"] = []
|
|
139
|
-
op_dict["summary"] = []
|
|
140
|
-
op_dict["stack_info"] = []
|
|
141
|
-
|
|
142
|
-
all_mode_bool = not (summary_compare or md5_compare)
|
|
143
|
-
if all_mode_bool:
|
|
144
|
-
op_dict["data_name"] = []
|
|
145
|
-
|
|
146
|
-
for tensor in tensor_list:
|
|
147
|
-
if len(tensor) == 2:
|
|
148
|
-
op_dict['stack_info'].append(tensor['full_info'])
|
|
149
|
-
break
|
|
150
|
-
op_dict["op_name"].append(tensor['full_op_name'])
|
|
151
|
-
if not md5_compare:
|
|
152
|
-
if tensor['full_op_name'].find("input") != -1:
|
|
153
|
-
op_dict["input_struct"].append((tensor['dtype'], tensor['shape']))
|
|
154
|
-
elif tensor['full_op_name'].find("kwarg") != -1:
|
|
155
|
-
op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape']))
|
|
156
|
-
elif tensor['full_op_name'].find("output") != -1:
|
|
157
|
-
op_dict["output_struct"].append((tensor['dtype'], tensor['shape']))
|
|
158
|
-
else:
|
|
159
|
-
if tensor['full_op_name'].find("input") != -1:
|
|
160
|
-
op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
|
|
161
|
-
elif tensor['full_op_name'].find("kwarg") != -1:
|
|
162
|
-
op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
|
|
163
|
-
elif tensor['full_op_name'].find("output") != -1:
|
|
164
|
-
op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5']))
|
|
165
|
-
|
|
166
|
-
op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
|
|
167
|
-
|
|
168
|
-
if all_mode_bool:
|
|
169
|
-
op_dict["data_name"].append(tensor['data_name'])
|
|
170
|
-
|
|
171
|
-
if not op_dict["kwargs_struct"]:
|
|
172
|
-
del op_dict["kwargs_struct"]
|
|
173
|
-
return op_dict if op_dict["op_name"] else {}
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
def match_op(npu_queue, bench_queue, fuzzy_match):
|
|
177
|
-
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
178
|
-
if check_op(npu_queue[-1], b_op, fuzzy_match):
|
|
179
|
-
return len(npu_queue) - 1, b_index
|
|
180
|
-
if check_op(npu_queue[-1], bench_queue[-1], fuzzy_match):
|
|
181
|
-
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
182
|
-
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
183
|
-
if check_op(n_op, bench_queue[-1], fuzzy_match):
|
|
184
|
-
return n_index, len(bench_queue) - 1
|
|
185
|
-
return -1, -1
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False):
|
|
189
|
-
def get_accuracy_core(n_start, n_len, b_start, b_len, key):
|
|
190
|
-
min_len = min(n_len, b_len)
|
|
191
|
-
npu_stack_info = n_dict.get("stack_info", None)
|
|
192
|
-
bench_stack_info = b_dict.get("stack_info", None)
|
|
193
|
-
has_stack = npu_stack_info and bench_stack_info
|
|
194
|
-
|
|
195
|
-
all_mode_bool = not (summary_compare or md5_compare)
|
|
196
|
-
if all_mode_bool:
|
|
197
|
-
npu_data_name = n_dict.get("data_name", None)
|
|
198
|
-
bench_data_name = b_dict.get("data_name", None)
|
|
199
|
-
|
|
200
|
-
for index in range(min_len):
|
|
201
|
-
|
|
202
|
-
n_name = n_dict['op_name'][n_start + index]
|
|
203
|
-
b_name = b_dict['op_name'][b_start + index]
|
|
204
|
-
n_struct = n_dict[key][index]
|
|
205
|
-
b_struct = b_dict[key][index]
|
|
206
|
-
err_msg = ""
|
|
207
|
-
if md5_compare:
|
|
208
|
-
result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
|
|
209
|
-
n_struct[2], b_struct[2],
|
|
210
|
-
CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF]
|
|
211
|
-
if has_stack and index == 0 and key == "input_struct":
|
|
212
|
-
result_item.extend(npu_stack_info)
|
|
213
|
-
else:
|
|
214
|
-
result_item.append(CompareConst.NONE)
|
|
215
|
-
result.append(result_item)
|
|
216
|
-
continue
|
|
217
|
-
|
|
218
|
-
if summary_compare:
|
|
219
|
-
result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
|
|
220
|
-
" ", " ", " ", " ", " ", " ", " ", " "]
|
|
221
|
-
else:
|
|
222
|
-
result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1],
|
|
223
|
-
" ", " ", " ", " ", " "]
|
|
224
|
-
|
|
225
|
-
npu_summary_data = n_dict.get("summary")[n_start + index]
|
|
226
|
-
result_item.extend(npu_summary_data)
|
|
227
|
-
bench_summary_data = b_dict.get("summary")[b_start + index]
|
|
228
|
-
result_item.extend(bench_summary_data)
|
|
229
|
-
|
|
230
|
-
if summary_compare:
|
|
231
|
-
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
232
|
-
warning_flag = False
|
|
233
|
-
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
234
|
-
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
235
|
-
diff = npu_val - bench_val
|
|
236
|
-
if bench_val != 0:
|
|
237
|
-
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
238
|
-
else:
|
|
239
|
-
relative = "N/A"
|
|
240
|
-
result_item[start_idx + i] = diff
|
|
241
|
-
result_item[start_idx + i + 4] = relative
|
|
242
|
-
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
243
|
-
if magnitude_diff > 0.5:
|
|
244
|
-
warning_flag = True
|
|
245
|
-
else:
|
|
246
|
-
result_item[start_idx + i] = CompareConst.NONE
|
|
247
|
-
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
248
|
-
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
249
|
-
for i in range(start_idx, len(result_item)):
|
|
250
|
-
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
251
|
-
result_item[i] = f'{result_item[i]}\t'
|
|
252
|
-
|
|
253
|
-
result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
|
|
254
|
-
result_item.append(err_msg)
|
|
255
|
-
if has_stack and index == 0 and key == "input_struct":
|
|
256
|
-
result_item.extend(npu_stack_info)
|
|
257
|
-
else:
|
|
258
|
-
result_item.append(CompareConst.NONE)
|
|
259
|
-
if all_mode_bool:
|
|
260
|
-
result_item.append(npu_data_name[n_start + index])
|
|
261
|
-
|
|
262
|
-
result.append(result_item)
|
|
263
|
-
|
|
264
|
-
if n_len > b_len:
|
|
265
|
-
for index in range(b_len, n_len):
|
|
266
|
-
n_name = n_dict['op_name'][n_start + index]
|
|
267
|
-
n_struct = n_dict[key][index]
|
|
268
|
-
if md5_compare:
|
|
269
|
-
result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
|
|
270
|
-
n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN]
|
|
271
|
-
result.append(result_item)
|
|
272
|
-
continue
|
|
273
|
-
result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN,
|
|
274
|
-
n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "]
|
|
275
|
-
summary_data = n_dict.get("summary")[n_start + index]
|
|
276
|
-
result_item.extend(summary_data)
|
|
277
|
-
summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))]
|
|
278
|
-
result_item.extend(summary_data)
|
|
279
|
-
|
|
280
|
-
err_msg = ""
|
|
281
|
-
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
282
|
-
result_item.append(err_msg)
|
|
283
|
-
|
|
284
|
-
if has_stack and index == 0 and key == "input_struct":
|
|
285
|
-
result_item.extend(npu_stack_info)
|
|
286
|
-
else:
|
|
287
|
-
result_item.append(CompareConst.NONE)
|
|
288
|
-
if all_mode_bool:
|
|
289
|
-
result_item.append(npu_data_name[n_start + index])
|
|
290
|
-
|
|
291
|
-
result.append(result_item)
|
|
292
|
-
|
|
293
|
-
n_num = len(n_dict['op_name'])
|
|
294
|
-
b_num = len(b_dict['op_name'])
|
|
295
|
-
n_num_input = len([name for name in n_dict['op_name'] if 'input' in name])
|
|
296
|
-
b_num_input = len([name for name in b_dict['op_name'] if 'input' in name])
|
|
297
|
-
n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name])
|
|
298
|
-
b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name])
|
|
299
|
-
n_num_output = n_num - n_num_input - n_num_kwarg
|
|
300
|
-
b_num_output = b_num - b_num_input - b_num_kwarg
|
|
301
|
-
get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
|
|
302
|
-
get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct")
|
|
303
|
-
get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _do_multi_process(input_parma, result_df):
|
|
307
|
-
try:
|
|
308
|
-
result_df = _handle_multi_process(compare_ops, input_parma, result_df, multiprocessing.Manager().RLock())
|
|
309
|
-
return result_df
|
|
310
|
-
except ValueError as e:
|
|
311
|
-
logger.error('result dataframe is not found.')
|
|
312
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
def read_dump_data(result_df):
|
|
316
|
-
try:
|
|
317
|
-
npu_dump_name_list = result_df.iloc[0:, 0].tolist()
|
|
318
|
-
npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
|
|
319
|
-
op_name_mapping_dict = {}
|
|
320
|
-
for index, _ in enumerate(npu_dump_name_list):
|
|
321
|
-
npu_dump_name = npu_dump_name_list[index]
|
|
322
|
-
npu_dump_tensor = npu_dump_tensor_list[index]
|
|
323
|
-
op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
|
|
324
|
-
return op_name_mapping_dict
|
|
325
|
-
except ValueError as e:
|
|
326
|
-
logger.error('result dataframe is not found.')
|
|
327
|
-
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
328
|
-
except IndexError as e:
|
|
329
|
-
logger.error('result dataframe elements can not be access.')
|
|
330
|
-
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
def _handle_multi_process(func, input_parma, result_df, lock):
    """Split result_df into chunks and run `func` over them in a process pool.

    Args:
        func: worker callable with signature (idx, op_name_mapping_dict,
            df_chunk, lock, input_parma) returning a DataFrame chunk.
        input_parma: dict of comparison parameters, forwarded to `func`.
        result_df: full result DataFrame to be processed chunk-wise.
        lock: shared (picklable) lock forwarded to the workers.

    Returns:
        The concatenation of all processed chunks, re-indexed from 0.
    """
    # use roughly half the cores (rounded up) to leave room for other work
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # fewer rows than processes: hand the whole frame to a single worker
        df_chunks = [result_df]

    results = []
    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        # any worker failure aborts the whole pool; remaining .get() calls
        # will then surface the error to the caller
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError as e:
            logger.error("pool terminate failed")

    for process_idx, df_chunk in enumerate(df_chunks):
        # idx is the chunk's absolute row offset inside result_df
        idx = df_chunk_size * process_idx
        result = pool.apply_async(func,
                                  args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                  error_callback=err_call)
        results.append(result)
    final_results = [r.get() for r in results]
    pool.close()
    pool.join()
    return pd.concat(final_results, ignore_index=True)
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
def compare_ops(idx, dump_path_dict, result_df, lock, input_parma):
    """Compare each op row in a DataFrame chunk and store the metrics.

    Args:
        idx: absolute row offset of this chunk within the full DataFrame.
        dump_path_dict: op name -> [npu dump name, bench dump name].
        result_df: DataFrame chunk whose first column holds op names.
        lock: shared lock guarding writes to the DataFrame.
        input_parma: dict of comparison parameters.

    Returns:
        The chunk with metric columns filled (via _save_cmp_result).
    """
    cosine_values = []
    max_abs_errors = []
    max_relative_errors = []
    messages = []
    one_thousandth_ratios = []
    five_thousandth_ratios = []
    verbose = input_parma.get("is_print_compare_log")

    for row_idx in range(len(result_df)):
        op_name = result_df.iloc[row_idx, 0]
        if verbose:
            logger.info("start compare: {}".format(op_name))
        cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
            compare_by_op(op_name, dump_path_dict, input_parma)
        if verbose:
            logger.info(
                "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
                "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
                                                    one_thousand_err_ratio, five_thousand_err_ratio))
        cosine_values.append(cos_sim)
        max_abs_errors.append(max_abs_err)
        max_relative_errors.append(max_relative_err)
        messages.append(err_msg)
        one_thousandth_ratios.append(one_thousand_err_ratio)
        five_thousandth_ratios.append(five_thousand_err_ratio)

    collected = ComparisonResult(
        cos_result=cosine_values,
        max_err_result=max_abs_errors,
        max_relative_err_result=max_relative_errors,
        err_msgs=messages,
        one_thousand_err_ratio_result=one_thousandth_ratios,
        five_thousand_err_ratio_result=five_thousandth_ratios
    )

    return _save_cmp_result(idx, collected, result_df, lock)
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
@dataclass
class ComparisonResult:
    """Per-chunk accuracy metrics collected by compare_ops, one list entry per compared row."""
    # cosine similarity between NPU and bench values, per row
    cos_result: list
    # maximum absolute error, per row
    max_err_result: list
    # maximum relative error, per row
    max_relative_err_result: list
    # human-readable message per row (errors / fuzzy-match notes)
    err_msgs: list
    # ratio of elements within the 0.001 relative-error band, per row
    one_thousand_err_ratio_result: list
    # ratio of elements within the 0.005 relative-error band, per row
    five_thousand_err_ratio_result: list
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Save comparison results into the result DataFrame with thread safety.

    Args:
        offset: absolute row offset of this chunk within the full DataFrame
        result: data struct of ComparisonResult
        result_df: DataFrame chunk to update in place
        lock: shared lock guarding the DataFrame writes

    Returns:
        the updated DataFrame chunk

    Raises:
        CompareException: when the DataFrame is missing or an index is out of range
    """
    with lock:
        try:
            for local_idx, cos_value in enumerate(result.cos_result):
                row = local_idx + offset
                result_df.loc[row, CompareConst.COSINE] = cos_value
                result_df.loc[row, CompareConst.MAX_ABS_ERR] = result.max_err_result[local_idx]
                result_df.loc[row, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[local_idx]
                result_df.loc[row, CompareConst.ERROR_MESSAGE] = result.err_msgs[local_idx]
                result_df.loc[row, CompareConst.ACCURACY] = check_accuracy(cos_value,
                                                                           result.max_err_result[local_idx])
                result_df.loc[row, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                    result.one_thousand_err_ratio_result[local_idx]
                result_df.loc[row, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                    result.five_thousand_err_ratio_result[local_idx]
            return result_df
        except ValueError as e:
            logger.error('result dataframe is not found.')
            raise CompareException(CompareException.INVALID_DATA_ERROR) from e
        except IndexError as e:
            logger.error('result dataframe elements can not be access.')
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
def check_accuracy(cos, max_abs_err):
    """Classify a row's accuracy from its cosine similarity and max absolute error.

    Args:
        cos: cosine similarity metric (number, numeric string, or sentinel).
        max_abs_err: maximum absolute error metric (same forms as `cos`).

    Returns:
        One of the CompareConst ACCURACY_CHECK_* / NONE sentinels.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if CompareConst.NONE in (cos, max_abs_err):
        return CompareConst.NONE
    if "N/A" in (cos, max_abs_err):
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos_value, err_value = float(cos), float(max_abs_err)
    except ValueError:
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    # fail when both metrics are mediocre, or when either is beyond its hard limit
    both_poor = cos_value < CompareConst.COS_THRESHOLD and err_value > CompareConst.MAX_ABS_ERR_THRESHOLD
    out_of_bounds = cos_value < CompareConst.COS_MAX_THRESHOLD or err_value > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD
    if both_poor or out_of_bounds:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
def read_npy_data(dir_path, file_name):
    """Load a dumped tensor file (.pt) and return its values as a numpy array.

    Args:
        dir_path: directory containing the dump file.
        file_name: dump file name inside `dir_path`.

    Returns:
        numpy array with the tensor data, always on CPU.
    """
    data_path = os.path.join(dir_path, file_name)
    checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
                          FileCheckConst.PT_SUFFIX, False)
    data_path = checker.common_check()
    # detach drops autograd history so the graph is not kept in memory
    tensor = torch.load(data_path, map_location=torch.device('cpu')).detach()
    if tensor.dtype == torch.bfloat16:
        # numpy has no bfloat16: widen to float32 before conversion
        tensor = tensor.to(torch.float32)
    return tensor.numpy()
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
def compare_by_op(op_name, op_name_mapping_dict, input_parma):
    """Compare one op's NPU dump against its bench dump and build its metric row.

    Args:
        op_name: full op name used as key into the mapping dict.
        op_name_mapping_dict: op name -> [npu dump name, bench dump name].
        input_parma: dict with "npu_dump_data_dir" and "bench_dump_data_dir".

    Returns:
        List of metric values for this op, ending with the error message.
    """
    mapping = op_name_mapping_dict[op_name]
    data_name = mapping[1]
    error_file, relative_err, error_flag = None, None, False
    if data_name in ('-1', -1):  # no real dump data was saved for this op
        n_value = b_value = CompareConst.READ_NONE
        error_flag = True
    else:
        try:
            n_value = read_npy_data(input_parma.get("npu_dump_data_dir"), mapping[0])
            b_value = read_npy_data(input_parma.get("bench_dump_data_dir"), mapping[1])
        except IOError as error:
            # record which file failed so the message can point at it
            error_file = error.filename
            n_value = b_value = CompareConst.READ_NONE
            error_flag = True

    n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag)
    if not error_flag:
        relative_err = get_relative_err(n_value, b_value)
        n_value, b_value = reshape_value(n_value, b_value)

    err_msg = get_error_message(n_value, b_value, op_name, error_flag, error_file=error_file)
    result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)

    # differing npu/bench names means the pair came from fuzzy matching
    if mapping[0] != mapping[1]:
        err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
    result_list.append(err_msg)
    return result_list
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
def handle_inf_nan(n_value, b_value):
    """Neutralize inf/nan entries shared by both arrays, or flag a mismatch.

    If both arrays have their inf and nan entries at exactly the same
    positions, those entries are zeroed in place so later metrics ignore
    them. If the positions differ, the pair is not comparable and
    (CompareConst.NAN, CompareConst.NAN) is returned instead.

    Args:
        n_value: NPU numpy array (modified in place when layouts match).
        b_value: bench numpy array (modified in place when layouts match).

    Returns:
        (n_value, b_value) with special values zeroed, or the NAN sentinels.
    """
    n_inf, b_inf = np.isinf(n_value), np.isinf(b_value)
    n_nan, b_nan = np.isnan(n_value), np.isnan(b_value)

    has_special = np.any(n_inf) or np.any(b_inf) or np.any(n_nan) or np.any(b_nan)
    if has_special:
        same_layout = np.array_equal(n_inf, b_inf) and np.array_equal(n_nan, b_nan)
        if not same_layout:
            return CompareConst.NAN, CompareConst.NAN
        n_value[n_inf] = 0
        b_value[b_inf] = 0
        n_value[n_nan] = 0
        b_value[b_nan] = 0
    return n_value, b_value
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
    """Find the rows of a single API that should be highlighted.

    Applies the configured HighlightRules to each row (basic rules), then to
    every output/input pair of the API (comparison rules), and records the
    absolute row numbers in highlight_dict['red_rows'] / ['yellow_rows'].
    MD5 comparison has no numeric error to judge, so it is skipped entirely.

    Args:
        result: rows (sequences of cell values) belonging to one API.
        last_len: absolute row offset of this API's first row.
        n_num_input: number of input rows at the start of `result`.
        highlight_dict: accumulator with 'red_rows' and 'yellow_rows' lists.
        summary_compare: True when comparing summary statistics.
        md5_compare: True when comparing MD5 digests (no highlighting).
    """
    if md5_compare:
        return
    npu_max_index = get_header_index('NPU max', summary_compare)
    bench_max_index = get_header_index('Bench max', summary_compare)
    max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)

    red_lines, yellow_lines = [], []
    LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
    ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
    ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
    color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)

    # judge each individual input/output row of the API on its own
    for i, line in enumerate(result):
        num = last_len + i
        line_info = LineInfo(line_data=line, num_pointer=num)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, summary_compare)

    # judge each API output against each API input
    for n, api_out in enumerate(result[n_num_input:len(result)]):
        num = last_len + n_num_input + n
        if num in red_lines:
            # already flagged at the strongest level; nothing more to add
            continue
        # skip rows whose metric cells are not numeric (e.g. "N/A")
        if not isinstance(api_out[npu_max_index], (float, int)) \
                or not isinstance(api_out[bench_max_index], (float, int)) \
                or not isinstance(api_out[max_diff_index], (float, int)):
            continue
        for _, api_in in enumerate(result[0:n_num_input]):
            if not isinstance(api_in[npu_max_index], (float, int)) \
                    or not isinstance(api_in[bench_max_index], (float, int)) \
                    or not isinstance(api_in[max_diff_index], (float, int)):
                continue

            api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
            if summary_compare:
                for rule in HighlightRules.summary_compare_rules.values():
                    rule.apply(api_info, color_columns, summary_compare)
            else:
                for rule in HighlightRules.compare_rules.values():
                    rule.apply(api_info, color_columns, summary_compare)

    # red takes precedence: a row flagged red is removed from the yellow set
    highlight_dict.get('red_rows', []).extend(list(set(red_lines)))
    highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
def get_name_and_state(name):
    """Split an op name into its api/module prefix and its state.

    Args:
        name: full row name, e.g. "Functional.conv2d.0.forward_input.0".

    Returns:
        (prefix, state) where state is "input" when the name contains
        "input", otherwise "output".
    """
    state = "input" if "input" in name else "output"
    return name.split(state)[0], state
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
    """Group the DataFrame rows by API and flag the erroneous ones for highlighting.

    Walks the rows in order, tracking where one API's input rows end and its
    output rows begin (via get_name_and_state); each completed API group is
    handed to find_error_rows, which fills highlight_dict.

    Args:
        result_df: full comparison result DataFrame; column 0 holds row names.
        highlight_dict: accumulator with 'red_rows' and 'yellow_rows' lists.
        summary_compare: True when comparing summary statistics.
        md5_compare: True when comparing MD5 digests.
    """
    result = result_df.values
    start, input_num, output_num, end = 0, 0, 0, len(result_df)
    last_api_name, last_state = None, None
    num, last_len = 0, 0
    for res_i in result:
        api_name, state = get_name_and_state(res_i[0])
        if last_api_name:
            if api_name == last_api_name:
                if state == last_state:
                    # still inside the same run of input (or output) rows
                    num += 1
                else:
                    # transition from input rows to output rows of the same API
                    input_num = num
                    num, last_state = 1, state
            else:
                # a new API begins: close out and process the previous group
                output_num = num
                find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
                                summary_compare, md5_compare)
                num, last_api_name, last_state = 1, api_name, state
                start += input_num + output_num
                input_num, output_num = 1, 0
        else:
            # very first row
            num, last_api_name, last_state = 1, api_name, state
    # process the trailing API group left after the loop
    if state:
        if state == "input":
            input_num = num
        else:
            output_num = num
        find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare)
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write the comparison result to an Excel file and highlight error rows.

    Rows listed in highlight_dict['red_rows'] are filled red, rows in
    highlight_dict['yellow_rows'] yellow. Values rendering as 'inf', '-inf'
    or 'nan' are written as strings with a trailing tab — presumably so the
    spreadsheet keeps them as text rather than coercing them (matches the
    original behavior; TODO confirm the intent).

    Args:
        result_df: comparison result DataFrame to export.
        highlight_dict: dict with 'red_rows' and 'yellow_rows' index lists
            (0-based positions within result_df).
        file_path: destination .xlsx path.
    """
    logger.info('Compare result is %s' % file_path)

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header
    for j, col_name in enumerate(result_df.columns, start=1):
        ws.cell(row=1, column=j, value=col_name)

    for i, row in enumerate(result_df.iterrows(), start=2):
        for j, value in enumerate(row[1], start=1):
            # Fix: the original applied the inf/nan formatting conditional twice
            # (once on the raw value, once on the already-formatted value);
            # a single pass produces the same cells.
            cell_value = value if isinstance(value, (float, int)) else str(value)
            if str(cell_value) in ('inf', '-inf', 'nan'):
                cell_value = f'{cell_value}\t'
            ws.cell(row=i, column=j, value=cell_value)

            # worksheet rows start at 2 (after the header), highlight_dict is 0-based
            if (i - 2) in highlight_dict['red_rows']:
                ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
                                                            end_color=CompareConst.RED, fill_type="solid")
            elif (i - 2) in highlight_dict['yellow_rows']:
                ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
                                                            end_color=CompareConst.YELLOW, fill_type="solid")
    wb.save(file_path)
    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
def compare(input_parma, output_path, stack_mode=False, auto_analyze=True,
            fuzzy_match=False):
    """Validate the comparison configuration, then run the full pipeline.

    Args:
        input_parma: dict of input paths and comparison parameters.
        output_path: directory that will receive the xlsx report.
        stack_mode: include stack information in the report.
        auto_analyze: run the advisor analysis after comparing.
        fuzzy_match: allow fuzzy op-name matching.
    """
    try:
        summary_compare, md5_compare = task_dumppath_get(input_parma)
        check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
        create_directory(output_path)
        check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare)
    except CompareException as error:
        # any validation failure aborts the whole run with the exception's code
        logger.error('Compare failed. Please check the arguments and do it again!')
        sys.exit(error.code)
    compare_core(
        input_parma,
        output_path,
        stack_mode=stack_mode,
        auto_analyze=auto_analyze,
        fuzzy_match=fuzzy_match,
        summary_compare=summary_compare,
        md5_compare=md5_compare,
    )
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
def compare_core(input_parma, output_path, **kwargs):
    """
    Compares data from multiple JSON files and generates a comparison report.

    Args:
        input_parma (dict): A dictionary containing paths to JSON files ("npu_json_path", "bench_json_path",
                            "stack_json_path").
        output_path (str): The path where the output Excel report will be saved.
        **kwargs: Additional keyword arguments including:
        - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False.
        - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
        - suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
        - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
        - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False.
        - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.

    Returns:
        None. The report is written as a timestamped xlsx under output_path.
    """
    # get kwargs or set default value
    stack_mode = kwargs.get('stack_mode', False)
    auto_analyze = kwargs.get('auto_analyze', True)
    suffix = kwargs.get('suffix', '')
    fuzzy_match = kwargs.get('fuzzy_match', False)
    summary_compare = kwargs.get('summary_compare', False)
    md5_compare = kwargs.get('md5_compare', False)

    logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
    file_name = add_time_with_xlsx("compare_result" + suffix)
    file_path = os.path.join(os.path.realpath(output_path), file_name)
    check_file_not_exists(file_path)
    highlight_dict = {'red_rows': [], 'yellow_rows': []}

    # parse and align the three dump files into one result table
    with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
            FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
            FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
        result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
                                    summary_compare, md5_compare)

    # real-data mode: fill metric columns by comparing the dumped tensors
    if not md5_compare and not summary_compare:
        result_df = _do_multi_process(input_parma, result_df)
    find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
    highlight_rows_xlsx(result_df, highlight_dict, file_path)
    if auto_analyze:
        advisor = Advisor(result_df, output_path)
        advisor.analysis()
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
def parse(pkl_file, module_name_prefix):
    """Log stack traces and statistics for dump entries matching a prefix.

    Reads the dump file line by line (each non-blank line is a JSON record),
    skipping records whose name does not start with module_name_prefix.
    Stack-info records are printed as tracebacks; records with statistics
    are printed under a "Statistic Info:" title (emitted once).

    Args:
        pkl_file: path of the line-delimited JSON dump file.
        module_name_prefix: only records whose name starts with this are shown.

    Raises:
        CompareException: INVALID_PARAM_ERROR when the prefix is not a string.
    """
    if not isinstance(module_name_prefix, str):
        logger.error("The parameter:module_name_prefix is not a string.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    with FileOpen(pkl_file, "r") as f:
        title_printed = False
        while True:
            pkl_line = f.readline()
            if pkl_line == '\n':
                # blank separator line
                continue
            if len(pkl_line) == 0:
                # end of file
                break

            msg = json.loads(pkl_line)
            info_prefix = msg[0]
            if not info_prefix.startswith(module_name_prefix):
                continue

            if info_prefix.find("stack_info") != -1:
                logger.info("\nTrace back({}):".format(msg[0]))
                # frames are stored innermost-first; print outermost-first
                for item in reversed(msg[1]):
                    logger.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2]))
                    logger.info(" {}".format(item[3]))
                continue
            if len(msg) > 5:
                summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
                    .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
                if not title_printed:
                    logger.info("\nStatistic Info:")
                    title_printed = True
                logger.info(summary_info)
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """Recursively flatten one op element into flat stat dicts.

    `item` may be: a full stat dict (has 'dtype'), a typed scalar wrapper
    (has 'type'/'value'), a special named-parameter dict, a nested sequence,
    or None / an empty dict (placeholder).

    Args:
        item: element to parse (see above).
        op_name: dotted name accumulated so far for this element.
        index: position of `item` inside its parent sequence, or None at the top.
        item_list: accumulator; a fresh list is created when omitted.
        top_bool: True only for the outermost call; selects the '.0' suffix
            for top-level placeholders.

    Returns:
        item_list with one flat dict appended per leaf element.
    """
    if item_list is None:
        item_list = []
    if item is None or (isinstance(item, dict) and not item):
        # missing data: emit a placeholder so row alignment is preserved
        suffix = '.0' if top_bool else '.' + str(index)
        item_list.append(_stat_placeholder(op_name + suffix))
        return item_list
    if index is None:
        full_op_name = op_name + '.0' if isinstance(item, dict) else op_name
    else:
        full_op_name = op_name + '.' + str(index)
    if isinstance(item, dict):
        if 'dtype' in item:
            # already a full stat dict: just tag it with its flattened name
            item['full_op_name'] = full_op_name
            item_list.append(item)
        elif 'type' in item:
            item_list.append(_parse_typed_item(item, full_op_name))
        else:
            resolve_api_special_parameters(item, full_op_name, item_list)
    else:
        # sequence: recurse into each element with its positional index
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
    return item_list


def _stat_placeholder(full_op_name):
    """Empty stat entry used when an element carries no data."""
    return {'full_op_name': full_op_name, 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
            'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}


def _parse_typed_item(item, full_op_name):
    """Build one stat entry for a {'type': ..., 'value': ...} wrapper."""
    item_type = item['type']
    if item_type == 'torch.Size':
        dtype, shape, stat = 'torch.Size', str(item['value']), None
    elif item_type == 'slice':
        dtype, shape, stat = 'slice', str(np.shape(np.array(item['value']))), None
    else:
        # plain scalar: its value doubles as every statistic
        dtype, shape, stat = str(type(item['value'])), '[]', item['value']
    return {'full_op_name': full_op_name, 'dtype': dtype, 'shape': shape, 'md5': None,
            'Max': stat, 'Min': stat, 'Mean': stat, 'Norm': stat, 'data_name': '-1'}
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse a special named-parameter format of api arguments, e.g.
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
        Each named sub-dict is appended to item_list with the parameter name
        spliced into full_op_name just before its trailing index.
    Parameter:
        data_dict: dict-format data
        full_op_name: full dotted name of the parameter
        item_list: accumulator for parsed parameter entries
    """
    for key, value in data_dict.items():
        if not isinstance(value, dict):
            continue
        name_parts = full_op_name.split(".")
        name_parts.insert(-1, key)
        value['full_op_name'] = ".".join(name_parts)
        item_list.append(value)
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
def read_op(op_data, op_name):
    """Flatten one dumped op record into a single list of parsed entries.

    Forward ops contribute input_args, input_kwargs and output; backward ops
    contribute grad_input (emitted under the "_input" suffix) and grad_output
    (under "_output").

    Args:
        op_data: raw dict for one op from the dump JSON.
        op_name: the op's full name; must contain 'forward' or 'backward'.

    Returns:
        List of flat stat dicts produced by op_item_parse.
    """
    op_parsed_list = []
    if 'forward' in op_name:
        if 'input_args' in op_data:
            op_parsed_list = list(op_item_parse(op_data['input_args'], op_name + '_input', None))
        if 'input_kwargs' in op_data:
            kwargs_item = op_data['input_kwargs']
            if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
                # a single typed value or a positional list: parse as one group
                op_parsed_list += op_item_parse(kwargs_item, op_name + '_input', None)
            elif kwargs_item:
                # a true kwargs mapping: parse each named argument separately
                for kwarg in kwargs_item:
                    op_parsed_list += op_item_parse(kwargs_item[kwarg], op_name + '_input.' + kwarg, None)
        if 'output' in op_data:
            op_parsed_list += op_item_parse(op_data['output'], op_name + '_output', None)
    if 'backward' in op_name:
        if 'grad_input' in op_data:
            op_parsed_list = list(op_item_parse(op_data['grad_input'], op_name + '_input', None))
        if 'grad_output' in op_data:
            op_parsed_list += op_item_parse(op_data['grad_output'], op_name + '_output', None)
    return op_parsed_list
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
    """Align NPU and bench dump records and build the comparison DataFrame.

    Ops are read from both dumps in parallel into two queues; match_op pairs
    them up (optionally fuzzily). NPU ops skipped over by a match are recorded
    as unmatched, as are any NPU ops left when both dumps are exhausted.

    Args:
        file_handles: [npu_json, bench_json, stack_json] open file handles.
        stack_mode: keep the stack-info column in the output.
        fuzzy_match: allow fuzzy op-name matching.
        summary_compare: build the summary-statistics layout.
        md5_compare: build the MD5 layout.

    Returns:
        pandas DataFrame with one row per compared (or unmatched) element.
    """
    npu_json_handle, bench_json_handle, stack_json_handle = file_handles
    npu_json_data = json.load(npu_json_handle)
    bench_json_data = json.load(bench_json_handle)
    stack_json_data = json.load(stack_json_handle)

    if fuzzy_match:
        logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")

    npu_ops_queue = []
    bench_ops_queue = []
    result = []

    ops_npu_iter = iter(npu_json_data['data'])
    ops_bench_iter = iter(bench_json_data['data'])
    # read_err_* are True while the corresponding iterator still has items
    read_err_npu = True
    read_err_bench = True
    last_npu_ops_len = 0
    last_bench_ops_len = 0

    while True:
        if not read_err_npu and not read_err_bench:
            # both dumps fully consumed
            break
        # pull the next NPU op (if any) into its queue
        try:
            last_npu_ops_len = len(npu_ops_queue)
            op_name_npu = next(ops_npu_iter)
            read_err_npu = True

            npu_op_data = npu_json_data['data'][op_name_npu]
            npu_op_parsed_list = read_op(npu_op_data, op_name_npu)
            # attach the stack info (or a None placeholder) as the final entry
            if op_name_npu in stack_json_data:
                npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': stack_json_data[op_name_npu]})
            else:
                npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': None})

            npu_merge_list = merge_tensor(npu_op_parsed_list, summary_compare, md5_compare)
            if npu_merge_list:
                npu_ops_queue.append(npu_merge_list)
        except StopIteration:
            read_err_npu = False
        # pull the next bench op (if any) into its queue
        try:
            last_bench_ops_len = len(bench_ops_queue)
            op_name_bench = next(ops_bench_iter)

            bench_op_data = bench_json_data['data'][op_name_bench]
            bench_op_parsed_list = read_op(bench_op_data, op_name_bench)
            if op_name_bench in stack_json_data:
                bench_op_parsed_list.append(
                    {'full_op_name': op_name_bench, 'full_info': stack_json_data[op_name_bench]})
            else:
                bench_op_parsed_list.append({'full_op_name': op_name_bench, 'full_info': None})

            bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare)
            if bench_merge_list:
                bench_ops_queue.append(bench_merge_list)
        except StopIteration:
            read_err_bench = False

        # merge all boolean expressions: only attempt matching when a queue grew
        both_empty = not npu_ops_queue and not bench_ops_queue
        no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
        if both_empty or no_change:
            continue

        n_match_point, b_match_point = match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
        if n_match_point == -1 and b_match_point == -1:
            # no pair yet; keep accumulating ops
            continue
        n_match_data = npu_ops_queue[n_match_point]
        b_match_data = bench_ops_queue[b_match_point]
        # NPU ops before the match point were skipped over: record as unmatched
        un_match_data = npu_ops_queue[0: n_match_point]
        for npu_data in un_match_data:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)
        get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare)
        # drop everything up to and including the matched pair
        del npu_ops_queue[0: n_match_point + 1]
        del bench_ops_queue[0: b_match_point + 1]
    # any NPU ops still queued at the end never found a bench counterpart
    if npu_ops_queue:
        for npu_data in npu_ops_queue:
            get_un_match_accuracy(result, npu_data, md5_compare, summary_compare)

    header = []
    if md5_compare:
        header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
    elif summary_compare:
        header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
    else:
        header = CompareConst.COMPARE_RESULT_HEADER[:]

    # real-data rows carry [..., stack, data_name]; trim columns the chosen
    # mode does not display
    all_mode_bool = not (summary_compare or md5_compare)
    if stack_mode:
        if all_mode_bool:
            header.append(CompareConst.STACK)
            header.append(CompareConst.DATA_NAME)
        else:
            header.append(CompareConst.STACK)
    else:
        if all_mode_bool:
            # drop the stack column (second to last) but keep data_name
            for row in result:
                del row[-2]
            header.append(CompareConst.DATA_NAME)
        else:
            # drop the trailing stack column
            for row in result:
                del row[-1]

    result_df = pd.DataFrame(result, columns=header)
    return result_df
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
988
|
-
index_out = 0
|
|
989
|
-
npu_stack_info = n_dict.get("stack_info", None)
|
|
990
|
-
bench_name, bench_type, bench_shape = CompareConst.NAN, CompareConst.NAN, CompareConst.NAN
|
|
991
|
-
err_msg = CompareConst.NO_BENCH
|
|
992
|
-
accuracy_check_res = CompareConst.NAN
|
|
993
|
-
for index, n_name in enumerate(n_dict["op_name"]):
|
|
994
|
-
if n_name.find("input") != -1:
|
|
995
|
-
n_struct = n_dict["input_struct"][index]
|
|
996
|
-
else:
|
|
997
|
-
n_struct = n_dict["output_struct"][index_out]
|
|
998
|
-
index_out += 1
|
|
999
|
-
|
|
1000
|
-
result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
|
|
1001
|
-
if md5_compare:
|
|
1002
|
-
result_item.extend([CompareConst.NAN] * 3)
|
|
1003
|
-
if npu_stack_info and index == 0:
|
|
1004
|
-
result_item.extend(npu_stack_info)
|
|
1005
|
-
result.append(result_item)
|
|
1006
|
-
continue
|
|
1007
|
-
if summary_compare:
|
|
1008
|
-
result_item.extend([CompareConst.NAN] * 8)
|
|
1009
|
-
else:
|
|
1010
|
-
result_item.extend([CompareConst.NAN] * 5)
|
|
1011
|
-
summary_data = n_dict.get("summary")[index]
|
|
1012
|
-
result_item.extend(summary_data)
|
|
1013
|
-
summary_data = [CompareConst.NAN] * 4
|
|
1014
|
-
result_item.extend(summary_data)
|
|
1015
|
-
result_item.append(accuracy_check_res)
|
|
1016
|
-
result_item.append(err_msg)
|
|
1017
|
-
if npu_stack_info and index == 0:
|
|
1018
|
-
result_item.extend(npu_stack_info)
|
|
1019
|
-
if not md5_compare and not summary_compare and result_item[1] == CompareConst.NAN:
|
|
1020
|
-
if index == 0:
|
|
1021
|
-
result_item.extend(["-1"])
|
|
1022
|
-
else:
|
|
1023
|
-
result_item.extend([CompareConst.NONE, "-1"])
|
|
1024
|
-
result.append(result_item)
|