mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,88 +1,148 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import multiprocessing
|
|
2
17
|
import os
|
|
3
|
-
import
|
|
18
|
+
import re
|
|
19
|
+
from copy import deepcopy
|
|
20
|
+
|
|
4
21
|
import pandas as pd
|
|
5
|
-
from msprobe.core.
|
|
22
|
+
from msprobe.core.advisor.advisor import Advisor
|
|
6
23
|
from msprobe.core.common.const import CompareConst, Const
|
|
7
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
8
|
-
from msprobe.core.common.
|
|
9
|
-
from msprobe.core.common.utils import add_time_with_xlsx, CompareException
|
|
25
|
+
from msprobe.core.common.file_utils import load_json
|
|
10
26
|
from msprobe.core.common.file_utils import remove_path
|
|
11
|
-
from msprobe.core.
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid, safe_get_value
|
|
29
|
+
from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
|
|
30
|
+
check_stack_json_str
|
|
12
31
|
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
13
|
-
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
|
|
14
32
|
from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
|
|
15
33
|
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \
|
|
16
34
|
get_error_message
|
|
17
|
-
from msprobe.core.
|
|
35
|
+
from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy, \
|
|
36
|
+
get_rela_diff_summary_mode, print_compare_ends_info
|
|
37
|
+
from tqdm import tqdm
|
|
18
38
|
|
|
19
39
|
|
|
20
40
|
class Comparator:
|
|
21
|
-
|
|
41
|
+
|
|
22
42
|
def __init__(self):
|
|
23
43
|
pass
|
|
24
|
-
|
|
25
|
-
@
|
|
26
|
-
def
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
44
|
+
|
|
45
|
+
@staticmethod
|
|
46
|
+
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
47
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
48
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
49
|
+
|
|
50
|
+
if len(npu_struct) < 3 or len(bench_struct) < 3:
|
|
51
|
+
logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
|
|
52
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
|
|
53
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
54
|
+
|
|
55
|
+
result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
|
|
56
|
+
npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
|
|
57
|
+
CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
|
|
58
|
+
|
|
59
|
+
if len(args) >= 2 and args[0]:
|
|
60
|
+
result_item.extend(args[1])
|
|
32
61
|
else:
|
|
33
|
-
|
|
62
|
+
result_item.append(CompareConst.NONE)
|
|
63
|
+
return result_item
|
|
64
|
+
|
|
65
|
+
@staticmethod
|
|
66
|
+
def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
|
|
67
|
+
err_msg = ""
|
|
68
|
+
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
69
|
+
bench_summary_data, err_msg)
|
|
70
|
+
result_item.append(accuracy_check)
|
|
71
|
+
result_item.append(err_msg)
|
|
72
|
+
|
|
73
|
+
@staticmethod
|
|
74
|
+
def _generate_na_data(ops_all):
|
|
75
|
+
if not ops_all:
|
|
76
|
+
return {}
|
|
77
|
+
key = next(iter(ops_all))
|
|
78
|
+
value = deepcopy(ops_all[key])
|
|
79
|
+
for k, v in value.items():
|
|
80
|
+
if isinstance(v, tuple):
|
|
81
|
+
value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
|
|
82
|
+
elif isinstance(v, list):
|
|
83
|
+
value[k] = [CompareConst.N_A] * len(v)
|
|
84
|
+
else:
|
|
85
|
+
value[k] = CompareConst.N_A
|
|
86
|
+
return value
|
|
87
|
+
|
|
88
|
+
@classmethod
|
|
89
|
+
def make_result_table(cls, result, stack_mode, dump_mode):
|
|
90
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode][:]
|
|
34
91
|
|
|
35
|
-
all_mode_bool = not (summary_compare or md5_compare)
|
|
36
92
|
if stack_mode:
|
|
37
|
-
|
|
38
|
-
|
|
93
|
+
header.append(CompareConst.STACK)
|
|
94
|
+
if dump_mode == Const.ALL:
|
|
39
95
|
header.append(CompareConst.DATA_NAME)
|
|
40
|
-
else:
|
|
41
|
-
header.append(CompareConst.STACK)
|
|
42
96
|
else:
|
|
43
|
-
if
|
|
97
|
+
if dump_mode == Const.ALL:
|
|
44
98
|
for row in result:
|
|
45
|
-
del row[-2]
|
|
99
|
+
del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
|
|
46
100
|
header.append(CompareConst.DATA_NAME)
|
|
47
101
|
else:
|
|
48
102
|
for row in result:
|
|
49
|
-
del row[-1]
|
|
50
|
-
result_df = pd.DataFrame(result, columns=header)
|
|
51
|
-
return result_df
|
|
52
|
-
|
|
103
|
+
del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
|
|
104
|
+
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
105
|
+
return result_df
|
|
106
|
+
|
|
53
107
|
@classmethod
|
|
54
|
-
def gen_merge_list(
|
|
108
|
+
def gen_merge_list(cls, json_data, op_name, stack_json_data, dump_mode):
|
|
55
109
|
op_data = json_data['data'][op_name]
|
|
110
|
+
check_dump_json_str(op_data, op_name)
|
|
56
111
|
op_parsed_list = read_op(op_data, op_name)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
112
|
+
|
|
113
|
+
stack_info = stack_json_data.get(op_name)
|
|
114
|
+
if stack_info is not None:
|
|
115
|
+
check_stack_json_str(stack_info, op_name)
|
|
116
|
+
op_parsed_list.append({
|
|
117
|
+
'full_op_name': op_name,
|
|
118
|
+
'full_info': stack_info
|
|
119
|
+
})
|
|
120
|
+
|
|
121
|
+
merge_list = merge_tensor(op_parsed_list, dump_mode)
|
|
63
122
|
return merge_list
|
|
64
|
-
|
|
123
|
+
|
|
65
124
|
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
graph_mode = check_graph_mode(
|
|
69
|
-
|
|
70
|
-
|
|
125
|
+
npu_op_name = npu_dict[CompareConst.OP_NAME]
|
|
126
|
+
bench_op_name = bench_dict[CompareConst.OP_NAME]
|
|
127
|
+
graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
|
|
128
|
+
safe_get_value(bench_op_name, 0, "bench_op_name"))
|
|
129
|
+
|
|
130
|
+
frame_name = getattr(self, "frame_name")
|
|
71
131
|
if frame_name == "PTComparator":
|
|
72
132
|
from msprobe.pytorch.compare.match import graph_mapping
|
|
73
133
|
if graph_mode:
|
|
74
|
-
return graph_mapping.match(
|
|
134
|
+
return graph_mapping.match(npu_op_name[0], bench_op_name[0])
|
|
75
135
|
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
76
136
|
if not fuzzy_match:
|
|
77
|
-
return
|
|
137
|
+
return npu_op_name == bench_op_name and struct_match
|
|
78
138
|
is_match = True
|
|
79
139
|
try:
|
|
80
|
-
is_match = fuzzy_check_op(
|
|
140
|
+
is_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
81
141
|
except Exception as err:
|
|
82
|
-
logger.warning("%s and %s can not fuzzy match." % (
|
|
142
|
+
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
83
143
|
is_match = False
|
|
84
144
|
return is_match and struct_match
|
|
85
|
-
|
|
145
|
+
|
|
86
146
|
def match_op(self, npu_queue, bench_queue, fuzzy_match):
|
|
87
147
|
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
88
148
|
if self.check_op(npu_queue[-1], b_op, fuzzy_match):
|
|
@@ -93,12 +153,12 @@ class Comparator:
|
|
|
93
153
|
if self.check_op(n_op, bench_queue[-1], fuzzy_match):
|
|
94
154
|
return n_index, len(bench_queue) - 1
|
|
95
155
|
return -1, -1
|
|
96
|
-
|
|
97
|
-
def compare_process(self,
|
|
98
|
-
|
|
99
|
-
npu_json_data =
|
|
100
|
-
bench_json_data =
|
|
101
|
-
stack_json_data =
|
|
156
|
+
|
|
157
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
|
|
158
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
159
|
+
npu_json_data = load_json(npu_json_path)
|
|
160
|
+
bench_json_data = load_json(bench_json_path)
|
|
161
|
+
stack_json_data = load_json(stack_json_path)
|
|
102
162
|
|
|
103
163
|
if fuzzy_match:
|
|
104
164
|
logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
|
|
@@ -114,14 +174,18 @@ class Comparator:
|
|
|
114
174
|
last_npu_ops_len = 0
|
|
115
175
|
last_bench_ops_len = 0
|
|
116
176
|
|
|
177
|
+
npu_api_nums = len(npu_json_data['data'])
|
|
178
|
+
progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
|
|
179
|
+
|
|
117
180
|
while True:
|
|
118
181
|
if not read_err_npu and not read_err_bench:
|
|
119
182
|
break
|
|
120
183
|
try:
|
|
121
184
|
last_npu_ops_len = len(npu_ops_queue)
|
|
122
185
|
op_name_npu = next(ops_npu_iter)
|
|
186
|
+
check_op_str_pattern_valid(op_name_npu)
|
|
123
187
|
read_err_npu = True
|
|
124
|
-
npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,
|
|
188
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data, dump_mode)
|
|
125
189
|
if npu_merge_list:
|
|
126
190
|
npu_ops_queue.append(npu_merge_list)
|
|
127
191
|
except StopIteration:
|
|
@@ -129,12 +193,15 @@ class Comparator:
|
|
|
129
193
|
try:
|
|
130
194
|
last_bench_ops_len = len(bench_ops_queue)
|
|
131
195
|
op_name_bench = next(ops_bench_iter)
|
|
132
|
-
|
|
196
|
+
check_op_str_pattern_valid(op_name_bench)
|
|
197
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data, dump_mode)
|
|
133
198
|
if bench_merge_list:
|
|
134
199
|
bench_ops_queue.append(bench_merge_list)
|
|
135
200
|
except StopIteration:
|
|
136
201
|
read_err_bench = False
|
|
137
202
|
|
|
203
|
+
progress_bar.update(1)
|
|
204
|
+
|
|
138
205
|
# merge all boolean expressions
|
|
139
206
|
both_empty = not npu_ops_queue and not bench_ops_queue
|
|
140
207
|
no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
|
|
@@ -153,24 +220,144 @@ class Comparator:
|
|
|
153
220
|
b_match_data = bench_ops_queue[b_match_point]
|
|
154
221
|
un_match_data = npu_ops_queue[0: n_match_point]
|
|
155
222
|
for npu_data in un_match_data:
|
|
156
|
-
get_un_match_accuracy(result, npu_data,
|
|
157
|
-
get_accuracy(result, n_match_data, b_match_data,
|
|
223
|
+
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
224
|
+
get_accuracy(result, n_match_data, b_match_data, dump_mode)
|
|
158
225
|
del npu_ops_queue[0: n_match_point + 1]
|
|
159
226
|
del bench_ops_queue[0: b_match_point + 1]
|
|
227
|
+
progress_bar.close()
|
|
160
228
|
if npu_ops_queue:
|
|
161
229
|
for npu_data in npu_ops_queue:
|
|
162
|
-
get_un_match_accuracy(result, npu_data,
|
|
163
|
-
|
|
164
|
-
result_df = self.make_result_table(result,
|
|
230
|
+
get_un_match_accuracy(result, npu_data, dump_mode)
|
|
231
|
+
|
|
232
|
+
result_df = self.make_result_table(result, stack_mode, dump_mode)
|
|
165
233
|
return result_df
|
|
166
|
-
|
|
167
|
-
def
|
|
234
|
+
|
|
235
|
+
def merge_data(self, json_data, stack_json_data, dump_mode):
|
|
236
|
+
ops_all = {}
|
|
237
|
+
for op_name in json_data.get('data', {}):
|
|
238
|
+
merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, dump_mode)
|
|
239
|
+
if merge_list:
|
|
240
|
+
input_index, output_index = 0, 0
|
|
241
|
+
for index, input_or_output in enumerate(merge_list[CompareConst.OP_NAME]):
|
|
242
|
+
input_or_output_list = input_or_output.split(Const.SEP)
|
|
243
|
+
data_name = merge_list.get('data_name')
|
|
244
|
+
data_name = data_name[index] if data_name else None
|
|
245
|
+
if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
|
|
246
|
+
ops_all[input_or_output] = {
|
|
247
|
+
CompareConst.STRUCT: safe_get_value(merge_list, input_index, "merge_list",
|
|
248
|
+
key=CompareConst.INPUT_STRUCT),
|
|
249
|
+
CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
|
|
250
|
+
key=CompareConst.SUMMARY),
|
|
251
|
+
'data_name': data_name,
|
|
252
|
+
'stack_info': merge_list.get('stack_info')
|
|
253
|
+
}
|
|
254
|
+
input_index += 1
|
|
255
|
+
|
|
256
|
+
elif Const.OUTPUT in input_or_output_list:
|
|
257
|
+
ops_all[input_or_output] = {
|
|
258
|
+
CompareConst.STRUCT: safe_get_value(merge_list, output_index, "merge_list",
|
|
259
|
+
key=CompareConst.OUTPUT_STRUCT),
|
|
260
|
+
CompareConst.SUMMARY: safe_get_value(merge_list, index, "merge_list",
|
|
261
|
+
key=CompareConst.SUMMARY),
|
|
262
|
+
'data_name': data_name,
|
|
263
|
+
'stack_info': merge_list.get('stack_info')
|
|
264
|
+
}
|
|
265
|
+
output_index += 1
|
|
266
|
+
return ops_all
|
|
267
|
+
|
|
268
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all, dump_mode):
|
|
269
|
+
result = []
|
|
270
|
+
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
271
|
+
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
272
|
+
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
273
|
+
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
274
|
+
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
275
|
+
has_stack = npu_stack_info and bench_stack_info
|
|
276
|
+
if dump_mode == Const.MD5:
|
|
277
|
+
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
278
|
+
bench_ops_all, has_stack, npu_stack_info))
|
|
279
|
+
continue
|
|
280
|
+
|
|
281
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
282
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
283
|
+
|
|
284
|
+
if len(npu_struct) < 2 or len(bench_struct) < 2:
|
|
285
|
+
logger.error(
|
|
286
|
+
f"The length of npu_struct and bench_struct must be >= 2, "
|
|
287
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
|
|
288
|
+
f"Please check!"
|
|
289
|
+
)
|
|
290
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
291
|
+
|
|
292
|
+
base_result_item = [
|
|
293
|
+
ms_op_name, bench_op_name,
|
|
294
|
+
npu_struct[0],
|
|
295
|
+
bench_struct[0],
|
|
296
|
+
npu_struct[1],
|
|
297
|
+
bench_struct[1]
|
|
298
|
+
]
|
|
299
|
+
|
|
300
|
+
if dump_mode == Const.SUMMARY:
|
|
301
|
+
result_item = base_result_item + [" "] * 8
|
|
302
|
+
else:
|
|
303
|
+
result_item = base_result_item + [" "] * 5
|
|
304
|
+
|
|
305
|
+
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
306
|
+
result_item.extend(npu_summary_data)
|
|
307
|
+
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
308
|
+
result_item.extend(bench_summary_data)
|
|
309
|
+
if dump_mode == Const.SUMMARY:
|
|
310
|
+
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
311
|
+
else:
|
|
312
|
+
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
313
|
+
result_item.append("")
|
|
314
|
+
if has_stack:
|
|
315
|
+
result_item.extend(npu_stack_info)
|
|
316
|
+
else:
|
|
317
|
+
result_item.append(CompareConst.NONE)
|
|
318
|
+
if dump_mode == Const.ALL:
|
|
319
|
+
result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
|
|
320
|
+
result.append(result_item)
|
|
321
|
+
elif ms_op_name not in npu_ops_all:
|
|
322
|
+
logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
|
|
323
|
+
elif bench_op_name not in npu_ops_all:
|
|
324
|
+
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
325
|
+
return result
|
|
326
|
+
|
|
327
|
+
def compare_process_custom(self, file_lists, stack_mode, dump_mode):
    """Build the comparison table for dumps that use a custom data mapping.

    Args:
        file_lists: sequence of exactly three paths, in this order: the NPU
            dump json, the bench dump json and the stack json.
        stack_mode: forwarded to ``self.make_result_table``.
        dump_mode: forwarded to ``self.merge_data``, ``self.get_accuracy``
            and ``self.make_result_table``.

    Returns:
        Whatever ``self.make_result_table`` produces. ``compare_core`` treats
        it as a pandas DataFrame.
    """
    npu_path, bench_path, stack_path = file_lists
    # Loaded in NPU -> bench -> stack order.
    npu_raw, bench_raw, stack_raw = [load_json(path) for path in (npu_path, bench_path, stack_path)]

    # Both sides are merged against the same stack json.
    merged_npu, merged_bench = (
        self.merge_data(raw, stack_raw, dump_mode) for raw in (npu_raw, bench_raw)
    )

    rows = self.get_accuracy(merged_npu, merged_bench, dump_mode)
    return self.make_result_table(rows, stack_mode, dump_mode)
|
|
339
|
+
|
|
340
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
|
|
341
|
+
"""
|
|
342
|
+
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
343
|
+
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
344
|
+
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
345
|
+
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
346
|
+
:param bench_data: bench的dump数据中"data"字段
|
|
347
|
+
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
348
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
|
|
349
|
+
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
350
|
+
"""
|
|
168
351
|
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
169
|
-
data_name = npu_bench_name_list
|
|
352
|
+
data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
|
|
170
353
|
error_file, relative_err, error_flag = None, None, False
|
|
354
|
+
bench_data_name = get_bench_data_name(bench_op_name, bench_data)
|
|
171
355
|
if data_name == '-1' or data_name == -1: # 没有真实数据路径
|
|
172
356
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
173
357
|
error_flag = True
|
|
358
|
+
elif not bench_data_name:
|
|
359
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
360
|
+
error_file = 'no_bench_data'
|
|
174
361
|
else:
|
|
175
362
|
try:
|
|
176
363
|
read_npy_data = getattr(self, "read_npy_data")
|
|
@@ -178,17 +365,18 @@ class Comparator:
|
|
|
178
365
|
if frame_name == "MSComparator":
|
|
179
366
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
|
|
180
367
|
if self.cross_frame:
|
|
181
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
368
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
|
|
369
|
+
load_pt_file=True)
|
|
182
370
|
else:
|
|
183
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
371
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
184
372
|
else:
|
|
185
373
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
|
|
186
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
374
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
187
375
|
except IOError as error:
|
|
188
376
|
error_file = error.filename
|
|
189
377
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
190
378
|
error_flag = True
|
|
191
|
-
except FileCheckException:
|
|
379
|
+
except (FileCheckException, CompareException):
|
|
192
380
|
error_file = data_name
|
|
193
381
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
194
382
|
error_flag = True
|
|
@@ -205,7 +393,7 @@ class Comparator:
|
|
|
205
393
|
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
206
394
|
result_list.append(err_msg)
|
|
207
395
|
return result_list
|
|
208
|
-
|
|
396
|
+
|
|
209
397
|
def compare_core(self, input_parma, output_path, **kwargs):
|
|
210
398
|
"""
|
|
211
399
|
Compares data from multiple JSON files and generates a comparison report.
|
|
@@ -219,8 +407,7 @@ class Comparator:
|
|
|
219
407
|
- auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True.
|
|
220
408
|
- suffix (str, optional): Suffix to append to the output file name. Defaults to ''.
|
|
221
409
|
- fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False.
|
|
222
|
-
-
|
|
223
|
-
- md5_compare (bool, optional): Enables MD5 comparison. Defaults to False.
|
|
410
|
+
- dump_mode (str): ALL, SUMMARY, MD5.
|
|
224
411
|
|
|
225
412
|
Returns:
|
|
226
413
|
"""
|
|
@@ -229,29 +416,43 @@ class Comparator:
|
|
|
229
416
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
230
417
|
suffix = kwargs.get('suffix', '')
|
|
231
418
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
232
|
-
|
|
233
|
-
md5_compare = kwargs.get('md5_compare', False)
|
|
419
|
+
dump_mode = kwargs.get('dump_mode', None)
|
|
234
420
|
|
|
235
421
|
logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
|
|
236
422
|
file_name = add_time_with_xlsx("compare_result" + suffix)
|
|
237
423
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
238
424
|
remove_path(file_path)
|
|
239
|
-
highlight_dict = {
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
425
|
+
highlight_dict = {"red_rows": set(), "yellow_rows": set(), "red_lines": [], "yellow_lines": []}
|
|
426
|
+
|
|
427
|
+
npu_json = input_parma.get("npu_json_path")
|
|
428
|
+
bench_json = input_parma.get("bench_json_path")
|
|
429
|
+
stack_json = input_parma.get("stack_json_path")
|
|
430
|
+
if self.data_mapping:
|
|
431
|
+
result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode, dump_mode)
|
|
432
|
+
else:
|
|
433
|
+
result_df = self.compare_process(
|
|
434
|
+
[npu_json, bench_json, stack_json],
|
|
435
|
+
stack_mode,
|
|
436
|
+
fuzzy_match,
|
|
437
|
+
dump_mode
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
if not result_df.values.tolist():
|
|
441
|
+
logger.warning("Can`t match any op.")
|
|
442
|
+
return
|
|
443
|
+
|
|
444
|
+
if dump_mode == Const.ALL:
|
|
445
|
+
result_df = self.do_multi_process(input_parma, result_df)
|
|
446
|
+
|
|
447
|
+
find_compare_result_error_rows(result_df, highlight_dict, dump_mode)
|
|
250
448
|
highlight_rows_xlsx(result_df, highlight_dict, file_path)
|
|
449
|
+
|
|
251
450
|
if auto_analyze:
|
|
252
|
-
advisor = Advisor(result_df, output_path)
|
|
451
|
+
advisor = Advisor(result_df, output_path, suffix)
|
|
253
452
|
advisor.analysis()
|
|
254
|
-
|
|
453
|
+
|
|
454
|
+
print_compare_ends_info()
|
|
455
|
+
|
|
255
456
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
256
457
|
cos_result = []
|
|
257
458
|
max_err_result = []
|
|
@@ -260,18 +461,22 @@ class Comparator:
|
|
|
260
461
|
one_thousand_err_ratio_result = []
|
|
261
462
|
five_thousand_err_ratio_result = []
|
|
262
463
|
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
464
|
+
bench_data = load_json(input_param.get("bench_json_path")).get('data')
|
|
263
465
|
for i in range(len(result_df)):
|
|
264
466
|
npu_op_name = result_df.iloc[i, 0]
|
|
265
467
|
bench_op_name = result_df.iloc[i, 1]
|
|
266
468
|
if is_print_compare_log:
|
|
267
469
|
logger.info("start compare: {}".format(npu_op_name))
|
|
268
|
-
|
|
269
|
-
|
|
470
|
+
|
|
471
|
+
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
|
|
472
|
+
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
|
|
473
|
+
|
|
270
474
|
if is_print_compare_log:
|
|
271
475
|
logger.info(
|
|
272
|
-
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {},
|
|
273
|
-
|
|
274
|
-
|
|
476
|
+
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
|
|
477
|
+
one_thousand_err_ratio {}, "
|
|
478
|
+
"five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
|
|
479
|
+
err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
|
|
275
480
|
cos_result.append(cos_sim)
|
|
276
481
|
max_err_result.append(max_abs_err)
|
|
277
482
|
max_relative_err_result.append(max_relative_err)
|
|
@@ -288,13 +493,46 @@ class Comparator:
|
|
|
288
493
|
five_thousand_err_ratio_result=five_thousand_err_ratio_result
|
|
289
494
|
)
|
|
290
495
|
|
|
291
|
-
return _save_cmp_result(idx, cr, result_df, lock)
|
|
292
|
-
|
|
293
|
-
def
|
|
496
|
+
return _save_cmp_result(idx, cr, result_df, lock)
|
|
497
|
+
|
|
498
|
+
def do_multi_process(self, input_parma, result_df):
    """Run ``self.compare_ops`` over ``result_df`` via ``_handle_multi_process``.

    A manager-backed re-entrant lock is created and handed to
    ``_handle_multi_process``, together with ``self.compare_ops`` and the inputs.

    Args:
        input_parma: comparison parameters, forwarded unchanged to
            ``_handle_multi_process``.
        result_df: table of matched op rows to compare.

    Returns:
        The table returned by ``_handle_multi_process``.

    Raises:
        CompareException: with ``INVALID_DATA_ERROR`` when
            ``_handle_multi_process`` raises ``ValueError``.
    """
    try:
        # Use the manager as a context manager so its server process is shut
        # down deterministically instead of lingering until garbage collection.
        # NOTE(review): assumes _handle_multi_process no longer uses the lock
        # once it returns (i.e. it waits for its workers) -- TODO confirm.
        with multiprocessing.Manager() as manager:
            return _handle_multi_process(self.compare_ops, input_parma, result_df, manager.RLock())
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
300
|
-
|
|
506
|
+
|
|
507
|
+
def get_bench_data_name(bench_op_name, bench_data):
    """Look up the data name of a bench tensor in the bench dump's "data" section.

    ``bench_op_name`` is split on ``.input.`` / ``.output.`` / ``.kwargs.``
    markers. For example, ``Functional.conv2d.0.forward.input.3.0`` yields the
    api key ``Functional.conv2d.0.forward``, the io type ``input`` and the index
    path ``3.0``.

    The index path is split on ``Const.SEP`` and walked through the matching
    entry of ``bench_data``:

    - a list is indexed by the integer value of each component;
    - a dict is indexed by key.

    Args:
        bench_op_name: bench op name as shown in the comparison table.
        bench_data: the "data" mapping loaded from the bench dump json. It may
            be None when that json has no "data" key.

    Returns:
        The value stored under ``CompareConst.DATA_NAME.lower()`` at the
        resolved position, or None when any step of the lookup fails.
    """
    # The caller builds bench_data with ``load_json(...).get('data')``, which
    # can be None; treat any non-mapping as "nothing to find".
    if not isinstance(bench_data, dict):
        return None

    bench_name_list = re.split(r'\.(input|output|kwargs)\.', bench_op_name)
    bench_data_bundle = bench_data.get(bench_name_list[0], {})
    if not isinstance(bench_data_bundle, dict) or not bench_data_bundle or len(bench_name_list) < 3:
        return None
    layers = bench_name_list[2].split(Const.SEP)

    def _lookup(key, container):
        # Dicts are indexed by key, lists by the integer value of the component.
        if isinstance(container, dict):
            return container.get(key)
        if isinstance(container, list):
            try:
                return container[int(key)]
            except (ValueError, IndexError):
                return None
        return None

    def _resolve(container):
        # Walk every path component; a failed step yields None, and every
        # later step keeps returning None.
        node = container
        for layer in layers:
            node = _lookup(layer, node)
        return _lookup(CompareConst.DATA_NAME.lower(), node)

    io_type = bench_name_list[1]
    if io_type == Const.INPUT:
        # Prefer the Const.INPUT entry; fall back to Const.INPUT_ARGS only when
        # the former key is absent.
        return _resolve(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
    if io_type == Const.KWARGS:
        return _resolve(bench_data_bundle.get(Const.INPUT_KWARGS))
    if io_type == Const.OUTPUT:
        return _resolve(bench_data_bundle.get(Const.OUTPUT))
    return None
|
|
538
|
+
|