mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,96 +13,129 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import math
|
|
17
16
|
import abc
|
|
17
|
+
import math
|
|
18
|
+
import multiprocessing
|
|
18
19
|
import re
|
|
19
20
|
from collections import namedtuple
|
|
21
|
+
|
|
20
22
|
import numpy as np
|
|
21
23
|
import openpyxl
|
|
22
24
|
from openpyxl.styles import PatternFill
|
|
23
|
-
from
|
|
25
|
+
from openpyxl.utils.dataframe import dataframe_to_rows
|
|
26
|
+
from tqdm import tqdm
|
|
27
|
+
|
|
28
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
24
29
|
from msprobe.core.common.file_utils import save_workbook
|
|
25
30
|
from msprobe.core.common.log import logger
|
|
26
|
-
from msprobe.core.common.
|
|
31
|
+
from msprobe.core.common.utils import get_header_index, safe_get_value
|
|
32
|
+
from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
class HighlightCheck(abc.ABC):
    """Abstract base class for one highlight rule over compare-result rows.

    Concrete rules implement ``apply``; the subclasses in this module record the
    rows they flag into the ``red`` / ``yellow`` lists of ``color_columns``.
    """

    @abc.abstractmethod
    def apply(self, info, color_columns, dump_mode):
        """Evaluate this rule and record any flagged row.

        Args:
            info: rule-specific payload; at the call sites in this module it is a
                ``LineInfo`` (single row) or an ``ApiInfo`` (input vs. output) named tuple.
            color_columns: named tuple exposing ``red`` and ``yellow`` lists.
            dump_mode: comparison mode, used to resolve column indices.
        """
        raise NotImplementedError
|
|
33
39
|
|
|
34
40
|
|
|
41
|
+
def add_highlight_row_info(color_list, num, highlight_err_msg):
    """Record a highlight message for row ``num`` in ``color_list``.

    ``color_list`` holds ``(row_num, [messages])`` tuples. If an entry for ``num``
    already exists, the message is appended to that entry's message list.
    Otherwise a new ``(num, [highlight_err_msg])`` entry is appended. As long as
    entries are only added through this function, each row number appears at
    most once.

    Args:
        color_list: list of ``(row_num, list_of_messages)`` tuples; mutated in place.
        num: row index being highlighted.
        highlight_err_msg: human-readable reason for the highlight.
    """
    for existing_num, existing_err_msgs in color_list:
        if existing_num == num:
            # The tuple itself is immutable, but the message list inside it is not.
            existing_err_msgs.append(highlight_err_msg)
            return
    color_list.append((num, [highlight_err_msg]))
|
|
47
|
+
|
|
48
|
+
|
|
35
49
|
class CheckOrderMagnitude(HighlightCheck):
    """Check how many orders of magnitude the max diff grows from input/params to output."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the output row yellow when its max diff outgrows the input's by the configured orders.

        ``info`` is an ``(api_in, api_out, num)`` triple; ``num`` is the row index recorded.
        """
        api_in, api_out, num = info
        if dump_mode == Const.SUMMARY:
            diff_header = CompareConst.MAX_DIFF
        else:
            diff_header = CompareConst.MAX_ABS_ERR
        max_diff_index = get_header_index(diff_header, dump_mode)

        in_abs = abs(api_in[max_diff_index])
        out_abs = abs(api_out[max_diff_index])
        if in_abs > out_abs:
            return

        # Magnitudes below 1 count as order 0, so small diffs never trigger the rule.
        in_order = 0 if in_abs < 1 else math.log10(in_abs)
        out_order = 0 if out_abs < 1 else math.log10(out_abs)
        if out_order - in_order < CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
            return
        add_highlight_row_info(color_columns.yellow, num,
                               "maximum absolute error of both input/parameters and output exceed 1, "
                               "with the output larger by an order of magnitude")
|
|
46
64
|
|
|
47
65
|
|
|
48
66
|
class CheckOneThousandErrorRatio(HighlightCheck):
    """Check how the one-thousandth error ratio changes from input/params to output."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the output row red or yellow based on its one-thousandth error ratio.

        Red: the input's ratio is above ``ONE_THOUSAND_ERROR_IN_RED`` while the output's is
        below ``ONE_THOUSAND_ERROR_OUT_RED``. Yellow (otherwise): the output's ratio is lower
        than the input's by more than ``ONE_THOUSAND_ERROR_DIFF_YELLOW``.
        Non-numeric ratios are ignored.
        """
        api_in, api_out, num = info
        ratio_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
        in_ratio = api_in[ratio_index]
        out_ratio = api_out[ratio_index]
        numeric_types = (float, int)
        if not (isinstance(in_ratio, numeric_types) and isinstance(out_ratio, numeric_types)):
            return

        if in_ratio > CompareConst.ONE_THOUSAND_ERROR_IN_RED and out_ratio < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
            add_highlight_row_info(color_columns.red, num,
                                   "The input/parameters's one thousandth err ratio exceeds 0.9, "
                                   "while the output's is below 0.6")
            return

        if in_ratio - out_ratio > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
            add_highlight_row_info(color_columns.yellow, num,
                                   "The output's one thousandth err ratio decreases by more than 0.1 "
                                   "compared to the input/parameters's")
|
|
61
84
|
|
|
62
85
|
|
|
63
86
|
class CheckCosineSimilarity(HighlightCheck):
    """Check how much the cosine similarity drops from input/params to output."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the output row yellow when its cosine falls below the input's by more than the threshold.

        ``info`` is an ``(api_in, api_out, num)`` triple; non-numeric cosine values are ignored.
        """
        api_in, api_out, num = info
        cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
        in_cosine, out_cosine = api_in[cosine_index], api_out[cosine_index]
        both_numeric = isinstance(in_cosine, (float, int)) and isinstance(out_cosine, (float, int))
        if both_numeric and in_cosine - out_cosine > CompareConst.COSINE_DIFF_YELLOW:
            add_highlight_row_info(color_columns.yellow, num,
                                   "The output's cosine decreases by more than 0.1 "
                                   "compared to the input/parameters's")
|
|
72
98
|
|
|
73
99
|
|
|
74
100
|
class CheckMaxRelativeDiff(HighlightCheck):
    """Check the maximum relative error (max diff relative to the benchmark max)."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the output row red or yellow according to its maximum relative error.

        ``info`` is an ``(api_in, api_out, num)`` triple. A row's relative error is
        ``|max diff| / max(FLOAT_EPSILON, |bench max|)``.

        - Red: the output's relative error exceeds ``MAX_RELATIVE_OUT_RED``.
        - Yellow: the output's relative error exceeds ``MAX_RELATIVE_OUT_YELLOW`` while
          the input's is below ``MAX_RELATIVE_IN_YELLOW``.
        """
        api_in, api_out, num = info
        max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
        bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
        # Clamp the *magnitude* of the denominator. Clamping the signed value would replace
        # any negative bench max with FLOAT_EPSILON and yield a huge, spurious ratio.
        input_max_relative_diff = np.abs(
            np.divide(api_in[max_diff_index], max(Const.FLOAT_EPSILON, abs(api_in[bench_max_index]))))
        output_max_relative_diff = np.abs(
            np.divide(api_out[max_diff_index], max(Const.FLOAT_EPSILON, abs(api_out[bench_max_index]))))
        if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
                                                                                   (float, int)):
            return
        if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
            add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
        elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
              input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
            add_highlight_row_info(color_columns.yellow, num,
                                   "The output's maximum relative error exceeds 0.1, "
                                   "while the input/parameters's is below 0.01")
|
|
90
121
|
|
|
91
122
|
|
|
92
123
|
class CheckOverflow(HighlightCheck):
    """Check a single row for overflow: nan/inf extrema or an extremely large max diff."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the row red when its NPU max/min is in ``OVERFLOW_LIST`` or |max diff| > ``MAX_DIFF_RED``.

        ``info`` is a ``(line, num)`` pair; the extrema check takes precedence and
        short-circuits the max-diff check.
        """
        line, num = info
        npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
        npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
        diff_header = CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY else CompareConst.MAX_ABS_ERR
        max_diff_index = get_header_index(diff_header, dump_mode)

        extrema_overflow = any(str(line[idx]) in CompareConst.OVERFLOW_LIST
                               for idx in (npu_max_index, npu_min_index))
        if extrema_overflow:
            add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
            return

        # check if Max_Diff > 1e+10
        max_diff = line[max_diff_index]
        if isinstance(max_diff, (float, int)) and abs(max_diff) > CompareConst.MAX_DIFF_RED:
            add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
|
|
106
139
|
|
|
107
140
|
|
|
108
141
|
class HighlightRules:
|
|
@@ -122,15 +155,31 @@ class HighlightRules:
|
|
|
122
155
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
123
156
|
"check_max_relative_diff": CheckMaxRelativeDiff(),
|
|
124
157
|
}
|
|
125
|
-
|
|
126
158
|
|
|
127
|
-
|
|
159
|
+
|
|
160
|
+
def check_indices_numeric(api_items, indices: list):
    """Return True when every value of ``api_items`` at ``indices`` is an int or float.

    ``bool`` counts as numeric because it subclasses ``int``; an empty ``indices``
    list yields True.
    """
    for idx in indices:
        if not isinstance(api_items[idx], (float, int)):
            return False
    return True
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def apply_comparison_rules(api_info, dump_mode, color_columns):
    """Run every output-versus-input/params comparison rule for ``dump_mode``.

    Summary mode uses ``HighlightRules.summary_compare_rules``; every other mode
    uses ``HighlightRules.compare_rules``. Rules run in the dict's iteration order.
    """
    rules = (HighlightRules.summary_compare_rules if dump_mode == Const.SUMMARY
             else HighlightRules.compare_rules)
    for rule in rules.values():
        rule.apply(api_info, color_columns, dump_mode)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def find_error_rows(result, api_batch, highlight_dict, dump_mode):
    """Find the rows of a single API/module that need highlighting.

    Args:
        result: rows of one API; local row ``i`` corresponds to global row
            ``api_batch.start + i`` of the result table.
        api_batch: ``ApiBatch`` giving the index ranges of the API's input/params and output rows.
        highlight_dict: accumulator with ``red_rows`` / ``yellow_rows`` sets and
            ``red_lines`` / ``yellow_lines`` lists; updated in place (missing keys are created).
        dump_mode: comparison mode; nothing is highlighted in MD5 mode.
    """
    if dump_mode == Const.MD5:
        return
    npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
    bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
    max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
                                      else CompareConst.MAX_ABS_ERR, dump_mode)
    numeric_indices = [npu_max_index, bench_max_index, max_diff_index]

    red_lines, yellow_lines = [], []
    LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
    ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
    ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
    color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)

    api_batch_start = api_batch.start  # global index of the API's first (input) row
    # Local (within ``result``) exclusive ends of the input/params rows and of the output rows.
    params_slice_end = api_batch.params_end_index - api_batch_start
    output_slice_end = api_batch.output_end_index - api_batch_start

    # Single-row checks on every row of the API.
    for i, line in enumerate(result):
        line_info = LineInfo(line_data=line, num_pointer=api_batch_start + i)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, dump_mode)

    # The single-row check is only the (red) overflow check; rows it already flagged are
    # not compared further. red_lines holds (row_num, messages) tuples, so membership
    # must be tested on the row numbers, not on the tuples themselves.
    overflow_rows = {row_num for row_num, _ in red_lines}

    # Compare every output row with every input/parameters row.
    inputs_and_params = result[0:params_slice_end]
    for n, api_out in enumerate(result[params_slice_end:output_slice_end]):
        index = api_batch_start + params_slice_end + n
        if index in overflow_rows:
            continue
        if not check_indices_numeric(api_out, numeric_indices):
            continue
        for api_in in inputs_and_params:
            if not check_indices_numeric(api_in, numeric_indices):
                continue
            api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
            apply_comparison_rules(api_info, dump_mode, color_columns)

    red_lines_num_set = {x[0] for x in red_lines}
    yellow_lines_num_set = {x[0] for x in yellow_lines}
    # setdefault instead of get(key, default): results must never land in a throwaway container.
    highlight_dict.setdefault('red_rows', set()).update(red_lines_num_set)
    highlight_dict.setdefault('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
    highlight_dict.setdefault('red_lines', []).extend(red_lines)
    highlight_dict.setdefault('yellow_lines', []).extend(yellow_lines)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class ApiBatch:
    """Row-index bookkeeping for one API/module in the compare result table.

    ``increment`` advances the indices on the assumption that an API's rows arrive
    ordered input/kwargs, parameters, output, parameters_grad (NOTE(review): the order
    is inferred from how the indices are advanced; confirm against the producer).
    With that order the global row ranges are:

        input/kwargs:    [start, start + input_len)
        parameters:      [start + input_len, params_end_index)
        output:          [params_end_index, output_end_index)
        parameters_grad: [output_end_index, params_grad_end_index)
    """

    def __init__(self, api_name: str, start: int):
        self.api_name = api_name
        self.start = start
        # Number of input rows. The row at ``start`` is always counted as input.
        # NOTE(review): this happens even if that row's state is not input.
        self.input_len = 1
        self.params_end_index = start + 1  # exclusive end index of the params rows
        self.output_end_index = start + 1  # exclusive end index of the output rows
        self.params_grad_end_index = start + 1  # exclusive end index of the params_grad rows
        # Current state (input, kwargs, parameters, output, parameters_grad). It decides
        # which of input_len / params_end_index / output_end_index grows in ``increment``.
        self._state = Const.INPUT  # a new batch starts in the input state

    def __repr__(self):
        # Used when an ApiBatch is interpolated into log messages.
        return (f"{type(self).__name__}(api_name={self.api_name!r}, start={self.start}, "
                f"input_len={self.input_len}, params_end_index={self.params_end_index}, "
                f"output_end_index={self.output_end_index}, "
                f"params_grad_end_index={self.params_grad_end_index}, state={self._state!r})")

    def set_state(self, state: str):
        """Set the current state.

        Raises:
            ValueError: if ``state`` is not one of the recognized states.
        """
        if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
            self._state = state
        else:
            raise ValueError(f"Invalid state: {state}")

    def increment(self, state: str):
        """Account for one more row of this API in the given ``state``.

        Raises:
            ValueError: propagated from ``set_state`` for an unknown state.
        """
        self.set_state(state)
        if self._state == Const.INPUT or self._state == Const.KWARGS:
            self.input_len += 1
            self.params_end_index += 1
            self.output_end_index += 1
        if self._state == Const.PARAMS:
            self.params_end_index += 1
            self.output_end_index += 1
        if self._state == Const.OUTPUT:
            self.output_end_index += 1
        # Every row, whatever its state, extends the overall end index.
        self.params_grad_end_index += 1
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def api_batches_update(api_batches, api_name, state, index):
    """
    Add the compare-result row at ``index`` to ``api_batches`` (the list is modified in place).

    The row joins the last ApiBatch when ``api_name`` equals that batch's name, or when ``api_name``
    does not match Const.REGEX_FORWARD_BACKWARD and is a substring of the batch's name; the batch's
    counters are then advanced with ApiBatch.increment(state). Otherwise (and when ``api_batches``
    is empty) a new ApiBatch starting at ``index`` is appended.

    Once all rows of an API have been added, and assuming they arrive in the order
    input/kwargs -> parameters -> output -> parameters_grad (the order ApiBatch.increment counts in),
    the half-open global index ranges are:
        input:           [start: start + input_len]
        parameters:      [start + input_len: params_end_index]
        output:          [params_end_index: output_end_index]
        parameters_grad: [output_end_index: params_grad_end_index]

    NOTE(review): the row that opens a batch is always counted as an input row, whatever its state.

    Raises:
        CompareException: with INVALID_STATE_ERROR when ApiBatch rejects ``state``.
    """
    if not api_batches:
        api_batches.append(ApiBatch(api_name, index))
        return

    api_batch = api_batches[-1]
    if api_batch.api_name == api_name or (
            not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
        try:
            api_batch.increment(state)
        except ValueError as e:
            # Log identifying fields explicitly; interpolating the object alone only shows its address.
            logger.error(f"api_batch: {api_batch.api_name} (start index {api_batch.start}) "
                         f"with invalid state, please check! {e}")
            raise CompareException(CompareException.INVALID_STATE_ERROR) from e
    else:
        api_batches.append(ApiBatch(api_name, index))
|
|
162
280
|
|
|
163
|
-
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
|
|
164
|
-
if summary_compare:
|
|
165
|
-
for rule in HighlightRules.summary_compare_rules.values():
|
|
166
|
-
rule.apply(api_info, color_columns, summary_compare)
|
|
167
|
-
else:
|
|
168
|
-
for rule in HighlightRules.compare_rules.values():
|
|
169
|
-
rule.apply(api_info, color_columns, summary_compare)
|
|
170
281
|
|
|
171
|
-
|
|
172
|
-
|
|
282
|
+
def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
    """Split the compare result into per-API/module batches and collect the rows to highlight.

    The first field of every row (the full API name) is parsed with get_name_and_state and fed to
    api_batches_update. Then find_error_rows is called on each batch's rows
    [start: params_grad_end_index], which records red/yellow rows into ``highlight_dict``.
    A tqdm progress bar tracks the per-batch pass.
    """
    rows = result_df.values
    batches = []
    for row_position, row in enumerate(rows):
        name, state = get_name_and_state(safe_get_value(row, 0, "res_i"))
        api_batches_update(batches, name, state, row_position)

    with tqdm(total=len(batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress:
        for batch in batches:
            batch_rows = rows[batch.start: batch.params_grad_end_index]
            find_error_rows(batch_rows, batch, highlight_dict, dump_mode)
            progress.update(1)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def value_check(value, api_name=None, i=None, result_df_columns=None):
    """Log an error when ``value`` fails ``table_value_is_valid``; nothing is raised or returned.

    Args:
        value: the cell or header value to check.
        api_name: name of the API the row belongs to; only used in the log message.
        i: column position of ``value``; used to look up ``result_df_columns[i]`` for the message.
        result_df_columns: column names; when falsy, a shorter message without location is logged.
    """
    if table_value_is_valid(value):
        return
    if not result_df_columns:
        logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
        return
    column_name = result_df_columns[i]
    logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{column_name}], "
                 f"is not allowed to be written into the compare result xlsx.")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def df_malicious_value_check(df_chunk, result_df_columns):
    """Run value_check on every cell of ``df_chunk``.

    The first field of each row is passed along as the API name and the cell's position as the
    column index, so a log message can point at the offending API and column.
    """
    for record in df_chunk.itertuples(index=False):
        row_api_name = record[0]
        for position, cell in enumerate(record):
            value_check(cell, row_api_name, position, result_df_columns)
|
|
173
311
|
|
|
174
312
|
|
|
175
|
-
def
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
313
|
+
def handle_multi_process_malicious_value_check(func, result_df):
    """Check the header and every cell of ``result_df`` for values that must not be written to the xlsx.

    Column names are checked in this process with ``value_check``. The rows are split into chunks,
    and ``func(df_chunk, column_names)`` runs on each chunk in a multiprocessing pool. Problems
    are only logged: nothing is returned, and invalid values do not raise.

    Args:
        func: picklable callable ``func(df_chunk, result_df_columns)`` executed in worker processes.
        result_df: the compare result dataframe to check.
    """
    result_total_nums = len(result_df)
    # Use about half of the CPUs: ceil(cpu_count / 2).
    process_num = (multiprocessing.cpu_count() + 1) // 2

    if result_total_nums <= process_num:
        process_num = 1
        chunks = [result_df]
    else:
        chunk_size = result_total_nums // process_num
        chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]

    result_df_columns = result_df.columns.tolist()
    for column in result_df_columns:
        value_check(column)

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("Pool terminate failed")

    try:
        for df_chunk in chunks:
            try:
                pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
            except ValueError:
                # err_call runs on the pool's result-handler thread and may already have terminated
                # the pool while chunks were still being submitted. In that case apply_async raises
                # ValueError("Pool not running"). The failure is already logged, so stop submitting.
                break
    finally:
        # Always release the worker processes, even if submission was interrupted.
        pool.close()
        pool.join()
|
|
185
341
|
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
api_name, state = get_name_and_state(res_i[0])
|
|
194
|
-
if last_api_name:
|
|
195
|
-
if api_name == last_api_name:
|
|
196
|
-
if state == last_state:
|
|
197
|
-
num += 1
|
|
198
|
-
else:
|
|
199
|
-
input_num = num
|
|
200
|
-
num, last_state = 1, state
|
|
201
|
-
else:
|
|
202
|
-
output_num = num
|
|
203
|
-
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
204
|
-
summary_compare, md5_compare)
|
|
205
|
-
num, last_api_name, last_state = 1, api_name, state
|
|
206
|
-
start += input_num + output_num
|
|
207
|
-
input_num, output_num = 1, 0
|
|
208
|
-
else:
|
|
209
|
-
num, last_api_name, last_state = 1, api_name, state
|
|
210
|
-
if state:
|
|
211
|
-
if state == "input":
|
|
212
|
-
input_num = num
|
|
213
|
-
else:
|
|
214
|
-
output_num = num
|
|
215
|
-
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
216
|
-
summary_compare, md5_compare)
|
|
342
|
+
|
|
343
|
+
def compare_result_df_convert(value):
    """Convert one dataframe cell into the form written to the xlsx.

    - bool values and anything that is not an int/float are turned into ``str``;
    - the texts "inf", "-inf" and "nan" (from either path) get a trailing tab;
    - other ints and floats are returned unchanged.
    """
    special_texts = ("inf", "-inf", "nan")
    is_plain_number = isinstance(value, (float, int)) and not isinstance(value, bool)
    if not is_plain_number:
        text = str(value)
        return f"{text}\t" if text in special_texts else text
    if isinstance(value, float):
        text = str(value)
        if text in special_texts:
            return f"{text}\t"
    return value
|
|
217
349
|
|
|
218
350
|
|
|
219
351
|
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write and highlight results in Excel

    Steps:
    1. Merge the highlight messages into the error-message column.
    2. Log any values rejected by the malicious-value check.
    3. Write the converted dataframe, header first, to a new workbook.
    4. Fill whole rows red/yellow according to ``highlight_dict``.
    5. Save the workbook to ``file_path``.

    Args:
        result_df: compare result dataframe; update_highlight_err_msg may rewrite its
            error-message column.
        highlight_dict: its "red_rows"/"yellow_rows" entries hold 0-based row positions of result_df.
        file_path: destination xlsx path passed to save_workbook.
    """
    update_highlight_err_msg(result_df, highlight_dict)  # add highlight err_msg

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header and data rows
    logger.info('Initializing Excel file.')

    handle_multi_process_malicious_value_check(df_malicious_value_check, result_df)

    # DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map (same element-wise
    # semantics). Use map when it exists and fall back to applymap for older pandas. The lookup is
    # done on the class so that a column named "map" can never be picked up by attribute access.
    if hasattr(type(result_df), "map"):
        result_df_convert = result_df.map(compare_result_df_convert)
    else:
        result_df_convert = result_df.applymap(compare_result_df_convert)

    for row in dataframe_to_rows(result_df_convert, index=False, header=True):
        ws.append(row)

    # colour suspicious rows
    logger.info('Coloring Excel in progress.')
    col_len = len(result_df.columns)
    red_fill = PatternFill(
        start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
    )
    yellow_fill = PatternFill(
        start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
    )
    # Red first, then yellow, so a row present in both ends up yellow.
    for rows_key, fill in (("red_rows", red_fill), ("yellow_rows", yellow_fill)):
        for i in highlight_dict.get(rows_key, []):
            for j in range(1, col_len + 1):
                # +2: openpyxl rows/columns are 1-based and row 1 holds the header.
                ws.cell(row=i + 2, column=j).fill = fill

    logger.info('Saving Excel file to disk: %s' % file_path)
    save_workbook(wb, file_path)
|
|
248
387
|
|
|
249
388
|
|
|
250
|
-
def
|
|
251
|
-
if
|
|
252
|
-
return
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
389
|
+
def update_highlight_err_msg(result_df, highlight_dict):
    """Append the messages of highlighted lines to the error-message column of ``result_df``.

    Nothing is done in these cases:
    - results with at most one column;
    - md5 compare results (identified by the CompareConst.NPU_MD5 column);
    - results without a CompareConst.ERROR_MESSAGE column.

    Args:
        result_df: compare result dataframe; its CompareConst.ERROR_MESSAGE column is updated.
        highlight_dict: holds "red_lines"/"yellow_lines" as (row index, messages) pairs and "red_rows"
            as a set of row indices. Red lines are added to "red_rows". Yellow lines whose row is
            already in "red_rows" are skipped.
    """
    if result_df.shape[1] <= 1:
        return

    if CompareConst.NPU_MD5 in result_df.columns:
        return

    err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
    if err_msg is None:
        # No error-message column: nothing to update. Assigning None back would add an all-None column.
        return

    red_lines_num_set = highlight_dict.get('red_rows')
    if red_lines_num_set is None:
        red_lines_num_set = set()
        highlight_dict['red_rows'] = red_lines_num_set

    for color in ('red', 'yellow'):
        for line_index, messages in highlight_dict.get(f'{color}_lines', []):
            if color == 'yellow' and line_index in red_lines_num_set:
                continue  # yellow line already covered by a red line: skip it

            for msg in messages:
                if err_msg[line_index] == '':
                    err_msg[line_index] = msg
                else:
                    err_msg[line_index] += '\n' + msg

            if color == 'red':
                red_lines_num_set.add(line_index)

    result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.compare.layer_mapping.layer_mapping import (
|
|
17
|
+
generate_data_mapping_by_layer_mapping,
|
|
18
|
+
generate_api_mapping_by_layer_mapping,
|
|
19
|
+
)
|