mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,41 +15,57 @@
|
|
|
15
15
|
|
|
16
16
|
import multiprocessing
|
|
17
17
|
import os
|
|
18
|
+
import re
|
|
19
|
+
from copy import deepcopy
|
|
20
|
+
|
|
18
21
|
import pandas as pd
|
|
19
22
|
from tqdm import tqdm
|
|
20
|
-
|
|
23
|
+
|
|
24
|
+
from msprobe.core.advisor.advisor import Advisor
|
|
21
25
|
from msprobe.core.common.const import CompareConst, Const
|
|
22
26
|
from msprobe.core.common.exceptions import FileCheckException
|
|
27
|
+
from msprobe.core.common.file_utils import load_json, remove_path
|
|
23
28
|
from msprobe.core.common.log import logger
|
|
24
|
-
from msprobe.core.common.utils import add_time_with_xlsx,
|
|
25
|
-
from msprobe.core.
|
|
26
|
-
|
|
27
|
-
check_stack_json_str
|
|
29
|
+
from msprobe.core.common.utils import CompareException, add_time_with_xlsx, check_op_str_pattern_valid, safe_get_value
|
|
30
|
+
from msprobe.core.compare.check import check_dump_json_str, check_graph_mode, check_stack_json_str, \
|
|
31
|
+
check_struct_match, fuzzy_check_op
|
|
28
32
|
from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
|
|
29
|
-
from msprobe.core.compare.
|
|
30
|
-
from msprobe.core.compare.
|
|
31
|
-
from msprobe.core.compare.
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _handle_multi_process, _save_cmp_result
|
|
34
|
+
from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_flag_and_msg
|
|
35
|
+
from msprobe.core.compare.utils import get_accuracy, get_rela_diff_summary_mode, get_un_match_accuracy, merge_tensor, \
|
|
36
|
+
print_compare_ends_info, read_op, get_name_and_state, reorder_op_x_list
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ModeConfig:
|
|
40
|
+
def __init__(self, stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=None):
|
|
41
|
+
self.stack_mode = stack_mode
|
|
42
|
+
self.auto_analyze = auto_analyze
|
|
43
|
+
self.fuzzy_match = fuzzy_match
|
|
44
|
+
self.dump_mode = dump_mode
|
|
34
45
|
|
|
35
46
|
|
|
36
47
|
class Comparator:
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
48
|
+
def __init__(self, mode_config: ModeConfig):
|
|
49
|
+
self.stack_mode = mode_config.stack_mode
|
|
50
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
51
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
52
|
+
self.dump_mode = mode_config.dump_mode
|
|
40
53
|
|
|
41
54
|
@staticmethod
|
|
42
55
|
def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
56
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
57
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
58
|
+
|
|
59
|
+
if len(npu_struct) < 3 or len(bench_struct) < 3:
|
|
60
|
+
logger.error(f"The length of npu_struct and bench_struct must be >= 3, "
|
|
61
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. Please check!")
|
|
62
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
63
|
+
|
|
64
|
+
result_item = [ms_op_name, bench_op_name, npu_struct[0], bench_struct[0],
|
|
65
|
+
npu_struct[1], bench_struct[1], npu_struct[2], bench_struct[2],
|
|
66
|
+
CompareConst.PASS if npu_struct[2] == bench_struct[2] else CompareConst.DIFF]
|
|
67
|
+
|
|
68
|
+
if len(args) >= 2 and args[0]:
|
|
53
69
|
result_item.extend(args[1])
|
|
54
70
|
else:
|
|
55
71
|
result_item.append(CompareConst.NONE)
|
|
@@ -58,113 +74,102 @@ class Comparator:
|
|
|
58
74
|
@staticmethod
|
|
59
75
|
def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
|
|
60
76
|
err_msg = ""
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
64
|
-
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
65
|
-
diff = npu_val - bench_val
|
|
66
|
-
if bench_val != 0:
|
|
67
|
-
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
68
|
-
else:
|
|
69
|
-
relative = "N/A"
|
|
70
|
-
result_item[start_idx + i] = diff
|
|
71
|
-
result_item[start_idx + i + 4] = relative
|
|
72
|
-
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
73
|
-
if magnitude_diff > 0.5:
|
|
74
|
-
warning_flag = True
|
|
75
|
-
else:
|
|
76
|
-
result_item[start_idx + i] = CompareConst.NONE
|
|
77
|
-
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
78
|
-
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
79
|
-
for i in range(start_idx, len(result_item)):
|
|
80
|
-
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
81
|
-
result_item[i] = f'{result_item[i]}\t'
|
|
77
|
+
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
78
|
+
bench_summary_data, err_msg)
|
|
82
79
|
result_item.append(accuracy_check)
|
|
83
80
|
result_item.append(err_msg)
|
|
84
|
-
|
|
85
|
-
@classmethod
|
|
86
|
-
def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
|
|
87
|
-
if md5_compare:
|
|
88
|
-
header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
|
|
89
|
-
elif summary_compare:
|
|
90
|
-
header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
|
|
91
|
-
else:
|
|
92
|
-
header = CompareConst.COMPARE_RESULT_HEADER[:]
|
|
93
81
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
82
|
+
@staticmethod
|
|
83
|
+
def _generate_na_data(ops_all):
|
|
84
|
+
if not ops_all:
|
|
85
|
+
return {}
|
|
86
|
+
key = next(iter(ops_all))
|
|
87
|
+
value = deepcopy(ops_all[key])
|
|
88
|
+
for k, v in value.items():
|
|
89
|
+
if isinstance(v, tuple):
|
|
90
|
+
value[k] = tuple(CompareConst.N_A for _ in range(len(v)))
|
|
91
|
+
elif isinstance(v, list):
|
|
92
|
+
value[k] = [CompareConst.N_A] * len(v)
|
|
99
93
|
else:
|
|
100
|
-
|
|
94
|
+
value[k] = CompareConst.N_A
|
|
95
|
+
return value
|
|
96
|
+
|
|
97
|
+
def make_result_table(self, result):
|
|
98
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
|
|
99
|
+
|
|
100
|
+
if self.stack_mode:
|
|
101
|
+
header.append(CompareConst.STACK)
|
|
102
|
+
if self.dump_mode == Const.ALL:
|
|
103
|
+
header.append(CompareConst.DATA_NAME)
|
|
101
104
|
else:
|
|
102
|
-
if
|
|
105
|
+
if self.dump_mode == Const.ALL:
|
|
103
106
|
for row in result:
|
|
104
|
-
del row[-2]
|
|
107
|
+
del row[-2] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,真实数据时为倒数第2列
|
|
105
108
|
header.append(CompareConst.DATA_NAME)
|
|
106
109
|
else:
|
|
107
110
|
for row in result:
|
|
108
|
-
del row[-1]
|
|
111
|
+
del row[-1] # 输出结果不要堆栈信息时,删除中间结果result中的stack info,非真实数据时为倒数第1列
|
|
109
112
|
result_df = pd.DataFrame(result, columns=header, dtype='object')
|
|
110
|
-
return result_df
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
|
|
113
|
+
return result_df
|
|
114
|
+
|
|
115
|
+
def gen_merge_list(self, json_data, op_name, stack_json_data):
|
|
114
116
|
op_data = json_data['data'][op_name]
|
|
115
117
|
check_dump_json_str(op_data, op_name)
|
|
116
118
|
op_parsed_list = read_op(op_data, op_name)
|
|
117
119
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
120
|
+
if self.stack_mode:
|
|
121
|
+
stack_info = stack_json_data.get(op_name)
|
|
122
|
+
if stack_info is not None:
|
|
123
|
+
check_stack_json_str(stack_info, op_name)
|
|
124
|
+
# append only when stack_mode is True,
|
|
125
|
+
op_parsed_list.append({
|
|
126
|
+
'full_op_name': op_name,
|
|
127
|
+
'full_info': stack_info
|
|
128
|
+
})
|
|
129
|
+
|
|
130
|
+
merge_list = merge_tensor(op_parsed_list, self.dump_mode)
|
|
127
131
|
return merge_list
|
|
128
|
-
|
|
129
|
-
def check_op(self, npu_dict, bench_dict
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
graph_mode = check_graph_mode(
|
|
133
|
-
|
|
132
|
+
|
|
133
|
+
def check_op(self, npu_dict, bench_dict):
|
|
134
|
+
npu_op_name = npu_dict[CompareConst.OP_NAME]
|
|
135
|
+
bench_op_name = bench_dict[CompareConst.OP_NAME]
|
|
136
|
+
graph_mode = check_graph_mode(safe_get_value(npu_op_name, 0, "npu_op_name"),
|
|
137
|
+
safe_get_value(bench_op_name, 0, "bench_op_name"))
|
|
138
|
+
|
|
134
139
|
frame_name = getattr(self, "frame_name")
|
|
135
140
|
if frame_name == "PTComparator":
|
|
136
141
|
from msprobe.pytorch.compare.match import graph_mapping
|
|
137
142
|
if graph_mode:
|
|
138
|
-
return graph_mapping.match(
|
|
143
|
+
return graph_mapping.match(npu_op_name[0], bench_op_name[0])
|
|
139
144
|
struct_match = check_struct_match(npu_dict, bench_dict)
|
|
140
|
-
if not fuzzy_match:
|
|
141
|
-
|
|
142
|
-
|
|
145
|
+
if not self.fuzzy_match:
|
|
146
|
+
name_match = npu_op_name == bench_op_name
|
|
147
|
+
return name_match and struct_match
|
|
143
148
|
try:
|
|
144
|
-
|
|
149
|
+
name_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
145
150
|
except Exception as err:
|
|
146
|
-
logger.warning("%s and %s can not fuzzy match." % (
|
|
147
|
-
|
|
148
|
-
return
|
|
149
|
-
|
|
150
|
-
def match_op(self, npu_queue, bench_queue
|
|
151
|
+
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
152
|
+
name_match = False
|
|
153
|
+
return name_match and struct_match
|
|
154
|
+
|
|
155
|
+
def match_op(self, npu_queue, bench_queue):
|
|
151
156
|
for b_index, b_op in enumerate(bench_queue[0: -1]):
|
|
152
|
-
if self.check_op(npu_queue[-1], b_op
|
|
157
|
+
if self.check_op(npu_queue[-1], b_op):
|
|
153
158
|
return len(npu_queue) - 1, b_index
|
|
154
|
-
if self.check_op(npu_queue[-1], bench_queue[-1]
|
|
159
|
+
if self.check_op(npu_queue[-1], bench_queue[-1]):
|
|
155
160
|
return len(npu_queue) - 1, len(bench_queue) - 1
|
|
156
161
|
for n_index, n_op in enumerate(npu_queue[0: -1]):
|
|
157
|
-
if self.check_op(n_op, bench_queue[-1]
|
|
162
|
+
if self.check_op(n_op, bench_queue[-1]):
|
|
158
163
|
return n_index, len(bench_queue) - 1
|
|
159
164
|
return -1, -1
|
|
160
|
-
|
|
161
|
-
def compare_process(self, file_lists
|
|
165
|
+
|
|
166
|
+
def compare_process(self, file_lists):
|
|
162
167
|
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
163
168
|
npu_json_data = load_json(npu_json_path)
|
|
164
169
|
bench_json_data = load_json(bench_json_path)
|
|
165
|
-
stack_json_data = load_json(stack_json_path)
|
|
170
|
+
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
166
171
|
|
|
167
|
-
if fuzzy_match:
|
|
172
|
+
if self.fuzzy_match:
|
|
168
173
|
logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
|
|
169
174
|
|
|
170
175
|
npu_ops_queue = []
|
|
@@ -188,9 +193,7 @@ class Comparator:
|
|
|
188
193
|
last_npu_ops_len = len(npu_ops_queue)
|
|
189
194
|
op_name_npu = next(ops_npu_iter)
|
|
190
195
|
check_op_str_pattern_valid(op_name_npu)
|
|
191
|
-
|
|
192
|
-
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
|
|
193
|
-
summary_compare, md5_compare)
|
|
196
|
+
npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data)
|
|
194
197
|
if npu_merge_list:
|
|
195
198
|
npu_ops_queue.append(npu_merge_list)
|
|
196
199
|
except StopIteration:
|
|
@@ -199,8 +202,7 @@ class Comparator:
|
|
|
199
202
|
last_bench_ops_len = len(bench_ops_queue)
|
|
200
203
|
op_name_bench = next(ops_bench_iter)
|
|
201
204
|
check_op_str_pattern_valid(op_name_bench)
|
|
202
|
-
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data
|
|
203
|
-
summary_compare, md5_compare)
|
|
205
|
+
bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data)
|
|
204
206
|
if bench_merge_list:
|
|
205
207
|
bench_ops_queue.append(bench_merge_list)
|
|
206
208
|
except StopIteration:
|
|
@@ -219,78 +221,105 @@ class Comparator:
|
|
|
219
221
|
logger.info("Please check whether the number and calls of APIs in NPU and Bench models are consistent.")
|
|
220
222
|
break
|
|
221
223
|
|
|
222
|
-
n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue
|
|
224
|
+
n_match_point, b_match_point = self.match_op(npu_ops_queue, bench_ops_queue)
|
|
225
|
+
|
|
226
|
+
# 如果没有匹配到,数据放到队列中,跳过,直到后面匹配到,把匹配之前的api放到不匹配中
|
|
223
227
|
if n_match_point == -1 and b_match_point == -1:
|
|
224
228
|
continue
|
|
229
|
+
|
|
225
230
|
n_match_data = npu_ops_queue[n_match_point]
|
|
226
231
|
b_match_data = bench_ops_queue[b_match_point]
|
|
227
232
|
un_match_data = npu_ops_queue[0: n_match_point]
|
|
228
233
|
for npu_data in un_match_data:
|
|
229
|
-
get_un_match_accuracy(result, npu_data,
|
|
230
|
-
get_accuracy(result, n_match_data, b_match_data,
|
|
234
|
+
get_un_match_accuracy(result, npu_data, self.dump_mode)
|
|
235
|
+
get_accuracy(result, n_match_data, b_match_data, self.dump_mode)
|
|
231
236
|
del npu_ops_queue[0: n_match_point + 1]
|
|
232
237
|
del bench_ops_queue[0: b_match_point + 1]
|
|
238
|
+
progress_bar.close()
|
|
233
239
|
if npu_ops_queue:
|
|
234
240
|
for npu_data in npu_ops_queue:
|
|
235
|
-
get_un_match_accuracy(result, npu_data,
|
|
236
|
-
|
|
237
|
-
result_df = self.make_result_table(result
|
|
241
|
+
get_un_match_accuracy(result, npu_data, self.dump_mode)
|
|
242
|
+
|
|
243
|
+
result_df = self.make_result_table(result)
|
|
238
244
|
return result_df
|
|
239
245
|
|
|
240
|
-
def merge_data(self, json_data, stack_json_data
|
|
246
|
+
def merge_data(self, json_data, stack_json_data):
    """
    Flatten a dump json into a dict keyed by full op name.

    For every entry under ``json_data['data']`` a merge list is built with
    ``self.gen_merge_list``; entries that yield an empty merge list are
    ignored. Each reordered op name whose state maps to a struct key gets
    one record holding its struct, summary, data name and stack info.

    :param json_data: parsed dump json; only its 'data' field is iterated
    :param stack_json_data: parsed stack json, forwarded to gen_merge_list
    :return: dict mapping op full name -> record dict
    """
    merged_ops = {}
    for api_name in json_data.get('data', {}):
        merge_list = self.gen_merge_list(json_data, api_name, stack_json_data)
        if not merge_list:
            continue

        # Per-struct-kind cursor: how many items of each kind were consumed so far
        # for this API, used as the index into the matching struct list.
        struct_cursor = dict.fromkeys(
            (
                CompareConst.INPUT_STRUCT,
                CompareConst.OUTPUT_STRUCT,
                CompareConst.PARAMS_STRUCT,
                CompareConst.PARAMS_GRAD_STRUCT,
            ),
            0,
        )

        reordered_names, reordered_summaries, reordered_data_names = reorder_op_x_list(
            merge_list.get(CompareConst.OP_NAME),
            merge_list.get(Const.SUMMARY),
            merge_list.get('data_name'),
        )

        for position, full_name in enumerate(reordered_names):
            # Data names may be absent altogether; then every record stores None.
            if reordered_data_names:
                data_name = reordered_data_names[position]
            else:
                data_name = None

            _, state = get_name_and_state(full_name)
            struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(state)
            if not struct_key:
                # State has no associated struct kind: nothing to record.
                continue

            struct_value = safe_get_value(merge_list, struct_cursor.get(struct_key),
                                          "merge_list", key=struct_key)
            summary_value = safe_get_value(reordered_summaries, position, "summary_reorder")
            merged_ops[full_name] = {
                CompareConst.STRUCT: struct_value,
                CompareConst.SUMMARY: summary_value,
                'data_name': data_name,
                'stack_info': merge_list.get('stack_info')
            }
            struct_cursor[struct_key] += 1
    return merged_ops
|
|
265
280
|
|
|
266
|
-
def get_accuracy(self, npu_ops_all, bench_ops_all
|
|
281
|
+
def get_accuracy(self, npu_ops_all, bench_ops_all):
|
|
267
282
|
result = []
|
|
283
|
+
bench_ops_all[CompareConst.N_A] = self._generate_na_data(bench_ops_all)
|
|
268
284
|
for ms_op_name, bench_op_name in self.data_mapping_dict.items():
|
|
269
285
|
if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
|
|
270
286
|
npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
|
|
271
287
|
bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
|
|
272
288
|
has_stack = npu_stack_info and bench_stack_info
|
|
273
|
-
if
|
|
289
|
+
if self.dump_mode == Const.MD5:
|
|
274
290
|
result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
|
|
275
291
|
bench_ops_all, has_stack, npu_stack_info))
|
|
276
292
|
continue
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
293
|
+
|
|
294
|
+
npu_struct = npu_ops_all.get(ms_op_name).get('struct', [])
|
|
295
|
+
bench_struct = bench_ops_all.get(bench_op_name).get('struct', [])
|
|
296
|
+
|
|
297
|
+
if len(npu_struct) < 2 or len(bench_struct) < 2:
|
|
298
|
+
logger.error(
|
|
299
|
+
f"The length of npu_struct and bench_struct must be >= 2, "
|
|
300
|
+
f"but got npu_struct={len(npu_struct)} and bench_struct={len(bench_struct)}. "
|
|
301
|
+
f"Please check!"
|
|
302
|
+
)
|
|
303
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
304
|
+
|
|
305
|
+
base_result_item = [
|
|
306
|
+
ms_op_name, bench_op_name,
|
|
307
|
+
npu_struct[0],
|
|
308
|
+
bench_struct[0],
|
|
309
|
+
npu_struct[1],
|
|
310
|
+
bench_struct[1]
|
|
311
|
+
]
|
|
312
|
+
|
|
313
|
+
if self.dump_mode == Const.SUMMARY:
|
|
314
|
+
result_item = base_result_item + [" "] * 8
|
|
283
315
|
else:
|
|
284
|
-
result_item =
|
|
285
|
-
|
|
286
|
-
npu_ops_all.get(ms_op_name).get('struct')[1],
|
|
287
|
-
bench_ops_all.get(bench_op_name).get('struct')[1],
|
|
288
|
-
" ", " ", " ", " ", " "]
|
|
316
|
+
result_item = base_result_item + [" "] * 5
|
|
317
|
+
|
|
289
318
|
npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
|
|
290
319
|
result_item.extend(npu_summary_data)
|
|
291
320
|
bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
|
|
292
321
|
result_item.extend(bench_summary_data)
|
|
293
|
-
if
|
|
322
|
+
if self.dump_mode == Const.SUMMARY:
|
|
294
323
|
self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
|
|
295
324
|
else:
|
|
296
325
|
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
@@ -299,7 +328,7 @@ class Comparator:
|
|
|
299
328
|
result_item.extend(npu_stack_info)
|
|
300
329
|
else:
|
|
301
330
|
result_item.append(CompareConst.NONE)
|
|
302
|
-
if
|
|
331
|
+
if self.dump_mode == Const.ALL:
|
|
303
332
|
result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
|
|
304
333
|
result.append(result_item)
|
|
305
334
|
elif ms_op_name not in npu_ops_all:
|
|
@@ -308,26 +337,39 @@ class Comparator:
|
|
|
308
337
|
logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
|
|
309
338
|
return result
|
|
310
339
|
|
|
311
|
-
def compare_process_custom(self, file_lists
|
|
340
|
+
def compare_process_custom(self, file_lists):
    """
    Compare NPU and bench dumps using the user supplied data mapping.

    :param file_lists: sequence of exactly three paths, in order:
        npu json, bench json, stack json
    :return: the result table built by ``self.make_result_table``
    """
    npu_json_path, bench_json_path, stack_json_path = file_lists
    npu_dump = load_json(npu_json_path)
    bench_dump = load_json(bench_json_path)

    # The stack file is only read when stack mode is enabled.
    if self.stack_mode:
        stack_dump = load_json(stack_json_path)
    else:
        stack_dump = None

    npu_ops = self.merge_data(npu_dump, stack_dump)
    bench_ops = self.merge_data(bench_dump, stack_dump)

    accuracy_rows = self.get_accuracy(npu_ops, bench_ops)
    return self.make_result_table(accuracy_rows)
|
|
323
351
|
|
|
324
|
-
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
|
|
352
|
+
def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param, bench_data):
|
|
353
|
+
"""
|
|
354
|
+
:param npu_op_name: excel中的NPU_Name,例如:MintFunctional.conv2d.0.forward.input.3.0
|
|
355
|
+
:param bench_op_name: excel中的Bench_Name,例如:Functional.conv2d.0.forward.input.3.0
|
|
356
|
+
:param op_name_mapping_dict: op_name和npy或pt文件的映射关系
|
|
357
|
+
:param input_param: npu_json_path/bench_json_path/stack_json_path等参数
|
|
358
|
+
:param bench_data: bench的dump数据中"data"字段
|
|
359
|
+
:return: result_list,包含余弦相似度、最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率和错误信息
|
|
360
|
+
用于读取excel中的NPU_Name和Bench_Name,根据映射关系找到npy或pt文件,然后读取文件中的数据进行比较,计算余弦相似度、
|
|
361
|
+
最大绝对误差、最大相对误差、千分之一误差率、千分之五误差率并生成错误信息
|
|
362
|
+
"""
|
|
325
363
|
npu_bench_name_list = op_name_mapping_dict[npu_op_name]
|
|
326
|
-
data_name = npu_bench_name_list
|
|
364
|
+
data_name = safe_get_value(npu_bench_name_list, 1, "npu_bench_name_list")
|
|
327
365
|
error_file, relative_err, error_flag = None, None, False
|
|
366
|
+
bench_data_name = get_bench_data_name(bench_op_name, bench_data)
|
|
328
367
|
if data_name == '-1' or data_name == -1: # 没有真实数据路径
|
|
329
368
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
330
369
|
error_flag = True
|
|
370
|
+
elif not bench_data_name:
|
|
371
|
+
n_value, b_value, error_flag = CompareConst.READ_NONE, CompareConst.READ_NONE, True
|
|
372
|
+
error_file = 'no_bench_data'
|
|
331
373
|
else:
|
|
332
374
|
try:
|
|
333
375
|
read_npy_data = getattr(self, "read_npy_data")
|
|
@@ -335,42 +377,39 @@ class Comparator:
|
|
|
335
377
|
if frame_name == "MSComparator":
|
|
336
378
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
|
|
337
379
|
if self.cross_frame:
|
|
338
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
339
|
-
|
|
380
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name,
|
|
381
|
+
load_pt_file=True)
|
|
340
382
|
else:
|
|
341
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
342
|
-
bench_op_name + Const.NUMPY_SUFFIX)
|
|
383
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
343
384
|
else:
|
|
344
385
|
n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
|
|
345
|
-
b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
|
|
386
|
+
b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_data_name)
|
|
346
387
|
except IOError as error:
|
|
347
388
|
error_file = error.filename
|
|
348
389
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
349
390
|
error_flag = True
|
|
350
|
-
except FileCheckException:
|
|
391
|
+
except (FileCheckException, CompareException):
|
|
351
392
|
error_file = data_name
|
|
352
393
|
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
|
|
353
394
|
error_flag = True
|
|
354
395
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
n_value, b_value = reshape_value(n_value, b_value)
|
|
396
|
+
# 通过n_value, b_value同时得到错误标志和错误信息
|
|
397
|
+
n_value, b_value, error_flag, err_msg = get_error_flag_and_msg(n_value, b_value,
|
|
398
|
+
error_flag=error_flag, error_file=error_file)
|
|
359
399
|
|
|
360
|
-
err_msg =
|
|
361
|
-
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err)
|
|
400
|
+
result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg)
|
|
362
401
|
|
|
363
|
-
if npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
402
|
+
if self.fuzzy_match and npu_op_name != bench_op_name and bench_op_name != CompareConst.N_A:
|
|
364
403
|
err_msg += " Fuzzy matching data, the comparison accuracy may be affected."
|
|
365
404
|
result_list.append(err_msg)
|
|
366
405
|
return result_list
|
|
367
|
-
|
|
368
|
-
def compare_core(self,
|
|
406
|
+
|
|
407
|
+
def compare_core(self, input_param, output_path, **kwargs):
    """
    Compare the NPU and bench dumps and write a highlighted Excel report.

    Args:
        input_param (dict): Supplies the json paths under the keys
            "npu_json_path", "bench_json_path" and "stack_json_path"; it is
            also forwarded to the multi-process comparison in ALL dump mode.
        output_path (str): Directory in which the Excel report is saved.
        **kwargs: Only ``suffix`` (str, default '') is read here; it is
            appended to the report file name and passed to the advisor.
            The dump mode, data mapping and auto-analyze switches are taken
            from attributes of this instance.

    Returns:
        None. Returns early, without writing a report, when no op matched.
    """
    logger.info("Please check whether the input data belongs to you. If not, there may be security risks.")
    suffix = kwargs.get('suffix', '')

    report_name = add_time_with_xlsx("compare_result" + suffix)
    report_path = os.path.join(os.path.realpath(output_path), report_name)
    # Clear any existing file at the target path before writing the new report.
    remove_path(report_path)
    highlight_dict = dict(red_rows=set(), yellow_rows=set(), red_lines=[], yellow_lines=[])

    json_paths = [
        input_param.get("npu_json_path"),
        input_param.get("bench_json_path"),
        input_param.get("stack_json_path"),
    ]
    # A user supplied data mapping selects the custom matching flow.
    run_compare = self.compare_process_custom if self.data_mapping else self.compare_process
    result_df = run_compare(json_paths)

    if not result_df.values.tolist():
        logger.warning("Can`t match any op.")
        return

    # The multi-process comparison only runs in ALL dump mode.
    if self.dump_mode == Const.ALL:
        result_df = self.do_multi_process(input_param, result_df)

    find_compare_result_error_rows(result_df, highlight_dict, self.dump_mode)
    highlight_rows_xlsx(result_df, highlight_dict, report_path)

    if self.auto_analyze:
        Advisor(result_df, output_path, suffix).analysis()

    print_compare_ends_info()
|
|
456
|
+
|
|
426
457
|
def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
|
|
427
458
|
cos_result = []
|
|
428
459
|
max_err_result = []
|
|
@@ -431,13 +462,16 @@ class Comparator:
|
|
|
431
462
|
one_thousand_err_ratio_result = []
|
|
432
463
|
five_thousand_err_ratio_result = []
|
|
433
464
|
is_print_compare_log = input_param.get("is_print_compare_log")
|
|
465
|
+
bench_data = load_json(input_param.get("bench_json_path")).get('data')
|
|
434
466
|
for i in range(len(result_df)):
|
|
435
467
|
npu_op_name = result_df.iloc[i, 0]
|
|
436
468
|
bench_op_name = result_df.iloc[i, 1]
|
|
437
469
|
if is_print_compare_log:
|
|
438
470
|
logger.info("start compare: {}".format(npu_op_name))
|
|
471
|
+
|
|
439
472
|
cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
|
|
440
|
-
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
|
|
473
|
+
self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param, bench_data)
|
|
474
|
+
|
|
441
475
|
if is_print_compare_log:
|
|
442
476
|
logger.info(
|
|
443
477
|
"[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
|
|
@@ -460,9 +494,9 @@ class Comparator:
|
|
|
460
494
|
five_thousand_err_ratio_result=five_thousand_err_ratio_result
|
|
461
495
|
)
|
|
462
496
|
|
|
463
|
-
return _save_cmp_result(idx, cr, result_df, lock)
|
|
464
|
-
|
|
465
|
-
def
|
|
497
|
+
return _save_cmp_result(idx, cr, result_df, lock)
|
|
498
|
+
|
|
499
|
+
def do_multi_process(self, input_parma, result_df):
|
|
466
500
|
try:
|
|
467
501
|
result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
|
|
468
502
|
multiprocessing.Manager().RLock())
|
|
@@ -470,4 +504,46 @@ class Comparator:
|
|
|
470
504
|
except ValueError as e:
|
|
471
505
|
logger.error('result dataframe is not found.')
|
|
472
506
|
raise CompareException(CompareException.INVALID_DATA_ERROR) from e
|
|
473
|
-
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def get_bench_data_name(bench_op_name, bench_data):
    """
    Find the data-name entry recorded for a bench op inside the bench dump data.

    ``bench_op_name`` is split on the first-level section marker
    (``.input.``, ``.output.``, ``.kwargs.``, ``.parameters.`` or
    ``.parameters_grad.``) into an API prefix, the section and a dotted
    path. The path is then walked through the nested dicts/lists of the
    matching section of the API's bundle in ``bench_data``.

    :param bench_op_name: full bench op name, e.g.
        ``Functional.conv2d.0.forward.input.3.0``
    :param bench_data: the "data" field of the bench dump json; anything
        that is not a dict (for example ``None`` when the field is absent)
        is treated as "no data"
    :return: the value stored under ``CompareConst.DATA_NAME.lower()`` at the
        addressed leaf, or ``None`` when the name cannot be resolved
    """
    # The caller obtains bench_data via ``.get('data')``, which may be None;
    # report "not found" instead of raising AttributeError.
    if not isinstance(bench_data, dict):
        return None

    bench_name_list = re.split(r'\.(input|output|kwargs|parameters|parameters_grad)\.', bench_op_name)
    # Without a section marker there is no path to walk.
    if len(bench_name_list) < 3:
        return None

    api_name, section, layer_path = bench_name_list[0], bench_name_list[1], bench_name_list[2]
    if section == Const.PARAMS_GRAD:
        # Parameter gradients are stored under their own "<api>.<section>" key.
        bench_data_bundle = bench_data.get(api_name + Const.SEP + section, {})
    else:
        bench_data_bundle = bench_data.get(api_name, {})
    if not bench_data_bundle or not isinstance(bench_data_bundle, dict):
        return None

    layers = layer_path.split(Const.SEP)

    def _get(key, container):
        """Fetch ``key`` from a dict, or index ``int(key)`` in a list; None on any miss."""
        if isinstance(container, dict):
            return container.get(key)
        if isinstance(container, list):
            try:
                return container[int(key)]
            except (ValueError, IndexError):
                return None
        return None

    def get_by_layer(container, params_grad=False):
        """Walk ``layers`` through ``container`` and return the leaf's data name."""
        # In dump.json a parameters_grad entry has the shape key: [{}]; when the key
        # exists the list has exactly one element, but op_name only names down to
        # the key, so index '0' is added. A new list is built so that ``layers``
        # itself is left unmodified.
        path = layers + ['0'] if params_grad else layers
        data = container
        for layer in path:
            data = _get(layer, data)
        return _get(CompareConst.DATA_NAME.lower(), data)

    if section == Const.INPUT:
        # Positional inputs live under Const.INPUT, falling back to Const.INPUT_ARGS
        # when that key is absent.
        return get_by_layer(bench_data_bundle.get(Const.INPUT, bench_data_bundle.get(Const.INPUT_ARGS)))
    if section == Const.KWARGS:
        return get_by_layer(bench_data_bundle.get(Const.INPUT_KWARGS))
    if section == Const.OUTPUT:
        return get_by_layer(bench_data_bundle.get(Const.OUTPUT))
    if section == Const.PARAMS:
        return get_by_layer(bench_data_bundle.get(Const.PARAMS))
    if section == Const.PARAMS_GRAD:
        return get_by_layer(bench_data_bundle, params_grad=True)
    return None
|