mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/core/compare/check.py
CHANGED
|
@@ -1,5 +1,22 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
from msprobe.core.common.log import logger
|
|
2
|
-
from msprobe.core.compare.utils import rename_api
|
|
17
|
+
from msprobe.core.compare.utils import rename_api
|
|
18
|
+
from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
|
|
19
|
+
from msprobe.core.common.const import Const
|
|
3
20
|
|
|
4
21
|
|
|
5
22
|
dtype_mapping = {
|
|
@@ -18,24 +35,28 @@ dtype_mapping = {
|
|
|
18
35
|
"BFloat16": "torch.bfloat16",
|
|
19
36
|
"Complex64": "torch.complex64",
|
|
20
37
|
"Complex128": "torch.complex128"
|
|
21
|
-
|
|
38
|
+
}
|
|
22
39
|
|
|
23
40
|
|
|
24
|
-
def check_struct_match(npu_dict, bench_dict
|
|
41
|
+
def check_struct_match(npu_dict, bench_dict):
    """Decide whether the input/output structures of an NPU op and a bench op match.

    The structures are read from the ``input_struct`` and ``output_struct`` entries of
    each dict. An exact match of both returns True immediately. Otherwise both input
    structs must be non-empty and of equal length, and the input and output structs are
    then compared pairwise with ``check_type_shape_match``.

    Args:
        npu_dict: Op info of the NPU side. Presumably each struct is a sequence of
            ``(dtype, shape)`` entries (TODO confirm against the dump format).
        bench_dict: Op info of the benchmark side, same layout as ``npu_dict``.

    Returns:
        bool: True when both input and output structures are considered matching.

    Raises:
        CompareException: INDEX_OUT_OF_BOUNDS_ERROR when ``check_type_shape_match``
            raises a CompareException (e.g. a struct entry lacking dtype or shape).
    """
    npu_struct_in = npu_dict.get("input_struct")
    bench_struct_in = bench_dict.get("input_struct")
    npu_struct_out = npu_dict.get("output_struct")
    bench_struct_out = bench_dict.get("output_struct")

    is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
    if not is_match:
        # A missing (None) or empty input struct cannot be matched loosely; `not` covers
        # both cases, whereas len(None) would raise an unrelated TypeError.
        if not npu_struct_in or not bench_struct_in or len(npu_struct_in) != len(bench_struct_in):
            return False
        try:
            struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
            # Treat a missing output struct as empty so zip() inside the helper
            # does not fail on None.
            struct_out_is_match = check_type_shape_match(npu_struct_out or [], bench_struct_out or [])
        except CompareException as error:
            # Put each dict on its own line; previously they were glued together.
            err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
                      f'npu_dict: {npu_dict}\n' \
                      f'bench_dict: {bench_dict}'
            logger.error(err_msg)
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
        is_match = struct_in_is_match and struct_out_is_match
    return is_match
|
|
41
62
|
|
|
@@ -43,17 +64,27 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
|
|
|
43
64
|
def check_type_shape_match(npu_struct, bench_struct):
|
|
44
65
|
shape_type_match = False
|
|
45
66
|
for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
67
|
+
try:
|
|
68
|
+
npu_type = npu_type_shape[0]
|
|
69
|
+
npu_shape = npu_type_shape[1]
|
|
70
|
+
bench_type = bench_type_shape[0]
|
|
71
|
+
bench_shape = bench_type_shape[1]
|
|
72
|
+
except IndexError as error:
|
|
73
|
+
logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
|
|
74
|
+
f'should both be 2, please check!')
|
|
75
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
50
76
|
shape_match = npu_shape == bench_shape
|
|
51
77
|
type_match = npu_type == bench_type
|
|
52
78
|
if not type_match:
|
|
53
|
-
ms_type=
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
79
|
+
ms_type = [
|
|
80
|
+
[Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
|
|
81
|
+
[Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
|
|
82
|
+
]
|
|
83
|
+
torch_type = [
|
|
84
|
+
[Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
|
|
85
|
+
[Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
|
|
86
|
+
]
|
|
87
|
+
if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
|
|
57
88
|
type_match = True
|
|
58
89
|
else:
|
|
59
90
|
type_match = False
|
|
@@ -64,9 +95,9 @@ def check_type_shape_match(npu_struct, bench_struct):
|
|
|
64
95
|
|
|
65
96
|
|
|
66
97
|
def check_graph_mode(a_op_name, b_op_name):
    """Return True when exactly one of the two op names contains ``Const.ATEN``.

    Presumably this detects that the two dumps were taken in different graph
    modes (ATen vs. non-ATen ops) — confirm against callers.
    """
    a_is_aten = Const.ATEN in a_op_name
    b_is_aten = Const.ATEN in b_op_name
    # XOR of the two membership flags: True only when they disagree.
    return a_is_aten != b_is_aten
|
|
72
103
|
|
|
@@ -83,13 +114,64 @@ def fuzzy_check_op(npu_name_list, bench_name_list):
|
|
|
83
114
|
|
|
84
115
|
|
|
85
116
|
def fuzzy_check_name(npu_name, bench_name):
    """Compare two op names, tolerating differences that ``rename_api`` normalizes away.

    If both names contain ``Const.FORWARD`` (checked first) or both contain
    ``Const.BACKWARD``, the names are compared after normalization with
    ``rename_api`` using that tag. Otherwise they must be exactly equal.
    """
    for direction_tag in (Const.FORWARD, Const.BACKWARD):
        if direction_tag in npu_name and direction_tag in bench_name:
            return rename_api(npu_name, direction_tag) == rename_api(bench_name, direction_tag)
    return npu_name == bench_name
|
|
93
124
|
|
|
94
125
|
|
|
126
|
+
def check_dump_json_str(op_data, op_name):
    """Validate the string content of one op's dump-json entry.

    Positional inputs come from ``Const.INPUT_ARGS`` when that entry is truthy,
    otherwise from ``Const.INPUT``. Keyword inputs and outputs come from
    ``Const.INPUT_KWARGS`` and ``Const.OUTPUT``. Empty or missing sections are
    skipped. A dict section is checked as a whole; any other section is iterated
    and each truthy element is checked on its own.
    """
    sections = (
        op_data.get(Const.INPUT_ARGS, None) or op_data.get(Const.INPUT, None),
        op_data.get(Const.INPUT_KWARGS, None),
        op_data.get(Const.OUTPUT, None),
    )
    for section in sections:
        if not section:
            continue
        if isinstance(section, dict):
            check_json_key_value(section, op_name)
            continue
        # filter(None, ...) drops falsy elements, mirroring a skip-on-empty loop.
        for element in filter(None, section):
            check_json_key_value(element, op_name)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def check_json_key_value(input_output, op_name, depth=0):
    """Recursively validate the leaf key/value pairs of a dump-json structure.

    Lists are walked element by element and nested dict values are descended
    into. Every non-dict value found under a dict key is handed to
    ``valid_key_value``. Objects that are neither list nor dict are ignored.
    Once ``depth`` exceeds ``Const.MAX_DEPTH``, an error is logged and the
    walk stops without raising.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"string check of data info of {op_name} exceeds the recursion limit.")
        return
    child_depth = depth + 1
    if isinstance(input_output, list):
        for element in input_output:
            check_json_key_value(element, op_name, child_depth)
        return
    if not isinstance(input_output, dict):
        return
    for key, value in input_output.items():
        if isinstance(value, dict):
            check_json_key_value(value, op_name, child_depth)
        else:
            valid_key_value(key, value, op_name)
|
|
158
|
+
|
|
95
159
|
|
|
160
|
+
def valid_key_value(key, value, op_name):
    """Validate one key/value pair taken from a dump-json entry.

    A ``shape`` value must be a list or tuple, and a ``requires_grad`` value must
    be a bool. Every pair that passes those type checks — including well-typed
    ``shape``/``requires_grad`` values — is passed to ``check_op_str_pattern_valid``.

    Raises:
        CompareException: INVALID_OBJECT_TYPE_ERROR when ``shape`` or
            ``requires_grad`` has the wrong type.
    """
    bad_shape = key == "shape" and not isinstance(value, (list, tuple))
    if bad_shape:
        logger.error(f"shape of input or output of {op_name} is not list or tuple, please check!")
        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    bad_grad_flag = key == "requires_grad" and not isinstance(value, bool)
    if bad_grad_flag:
        logger.error(f"requires_grad of input or output of {op_name} is not bool, please check!")
        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    check_op_str_pattern_valid(value, op_name)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def check_stack_json_str(stack_info, op_name):
    """Validate every entry of an op's stack info with ``check_op_str_pattern_valid``.

    Raises:
        CompareException: INVALID_OBJECT_TYPE_ERROR when ``stack_info`` is not a list.
    """
    if not isinstance(stack_info, list):
        logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
    for frame in stack_info:
        check_op_str_pattern_valid(frame, op_name, stack=True)
|
|
@@ -1,15 +1,35 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import json
|
|
2
|
-
from msprobe.core.common.file_utils import
|
|
17
|
+
from msprobe.core.common.file_utils import check_file_type, load_json
|
|
3
18
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
4
19
|
from msprobe.core.common.utils import CompareException
|
|
5
20
|
from msprobe.core.common.log import logger
|
|
6
21
|
|
|
7
22
|
|
|
8
23
|
def compare_cli(args):
|
|
9
|
-
|
|
10
|
-
input_param = json.load(file)
|
|
24
|
+
input_param = load_json(args.input_path)
|
|
11
25
|
npu_path = input_param.get("npu_path", None)
|
|
12
26
|
bench_path = input_param.get("bench_path", None)
|
|
27
|
+
if not npu_path:
|
|
28
|
+
logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
|
|
29
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
30
|
+
if not bench_path:
|
|
31
|
+
logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
|
|
32
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
13
33
|
frame_name = args.framework
|
|
14
34
|
auto_analyze = not args.compare_only
|
|
15
35
|
if frame_name == Const.PT_FRAMEWORK:
|
|
@@ -19,12 +39,18 @@ def compare_cli(args):
|
|
|
19
39
|
from msprobe.mindspore.compare.ms_compare import ms_compare
|
|
20
40
|
from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
|
|
21
41
|
if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
|
|
42
|
+
if "stack_path" not in input_param:
|
|
43
|
+
logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
|
|
44
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
22
45
|
input_param["npu_json_path"] = input_param.pop("npu_path")
|
|
23
46
|
input_param["bench_json_path"] = input_param.pop("bench_path")
|
|
24
47
|
input_param["stack_json_path"] = input_param.pop("stack_path")
|
|
25
48
|
if frame_name == Const.PT_FRAMEWORK:
|
|
49
|
+
kwargs = {
|
|
50
|
+
"data_mapping": args.data_mapping
|
|
51
|
+
}
|
|
26
52
|
compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
|
|
27
|
-
fuzzy_match=args.fuzzy_match)
|
|
53
|
+
fuzzy_match=args.fuzzy_match, **kwargs)
|
|
28
54
|
else:
|
|
29
55
|
kwargs = {
|
|
30
56
|
"stack_mode": args.stack_mode,
|
|
@@ -32,11 +58,22 @@ def compare_cli(args):
|
|
|
32
58
|
"fuzzy_match": args.fuzzy_match,
|
|
33
59
|
"cell_mapping": args.cell_mapping,
|
|
34
60
|
"api_mapping": args.api_mapping,
|
|
61
|
+
"data_mapping": args.data_mapping,
|
|
62
|
+
"layer_mapping": args.layer_mapping
|
|
35
63
|
}
|
|
36
64
|
|
|
37
65
|
ms_compare(input_param, args.output_path, **kwargs)
|
|
38
66
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
39
|
-
kwargs = {
|
|
67
|
+
kwargs = {
|
|
68
|
+
"stack_mode": args.stack_mode,
|
|
69
|
+
"auto_analyze": auto_analyze,
|
|
70
|
+
"fuzzy_match": args.fuzzy_match,
|
|
71
|
+
"is_print_compare_log": input_param.get("is_print_compare_log", True),
|
|
72
|
+
"cell_mapping": args.cell_mapping,
|
|
73
|
+
"api_mapping": args.api_mapping,
|
|
74
|
+
"data_mapping": args.data_mapping,
|
|
75
|
+
"layer_mapping": args.layer_mapping
|
|
76
|
+
}
|
|
40
77
|
if input_param.get("rank_id") is not None:
|
|
41
78
|
ms_graph_compare(input_param, args.output_path)
|
|
42
79
|
return
|
|
@@ -1,89 +1,127 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import math
|
|
2
17
|
import abc
|
|
18
|
+
import re
|
|
3
19
|
from collections import namedtuple
|
|
4
20
|
import numpy as np
|
|
5
21
|
import openpyxl
|
|
6
22
|
from openpyxl.styles import PatternFill
|
|
23
|
+
from tqdm import tqdm
|
|
7
24
|
from msprobe.core.common.utils import get_header_index
|
|
8
25
|
from msprobe.core.common.file_utils import save_workbook
|
|
9
26
|
from msprobe.core.common.log import logger
|
|
10
|
-
from msprobe.core.common.const import CompareConst
|
|
27
|
+
from msprobe.core.common.const import CompareConst, FileCheckConst, Const
|
|
28
|
+
from msprobe.core.common.utils import safe_get_value
|
|
11
29
|
|
|
12
30
|
|
|
13
31
|
class HighlightCheck(abc.ABC):
    """Abstract base for one highlighting rule applied to compare-result rows.

    Concrete rules implement ``apply``; this base only fixes the signature.
    """

    @abc.abstractmethod
    def apply(self, info, color_columns, dump_mode):
        """Evaluate the rule on ``info`` and record any highlight in ``color_columns``."""
        raise NotImplementedError
|
|
17
35
|
|
|
18
36
|
|
|
37
|
+
def add_highlight_row_info(color_list, num, highlight_err_msg):
    """Record ``highlight_err_msg`` for row ``num`` in ``color_list``.

    ``color_list`` holds ``(row_num, messages)`` pairs. If the first pair whose
    row number equals ``num`` exists, the message is appended to its list in
    place. Otherwise a new ``(num, [highlight_err_msg])`` pair is appended.
    """
    for row_num, messages in color_list:
        if row_num == num:
            messages.append(highlight_err_msg)
            return
    color_list.append((num, [highlight_err_msg]))
|
|
43
|
+
|
|
44
|
+
|
|
19
45
|
class CheckOrderMagnitude(HighlightCheck):
    """Check the order-of-magnitude difference of the max diff between input and output."""

    def apply(self, info, color_columns, dump_mode):
        """Flag the row yellow when the output's max diff is orders of magnitude above the input's.

        ``info`` is unpacked as ``(api_in, api_out, num)``. The column read is
        ``CompareConst.MAX_DIFF`` in summary mode and ``CompareConst.MAX_ABS_ERR``
        otherwise. Magnitudes below 1 are treated as order 0. Rows whose cells are
        not real numbers are skipped, matching the sibling checks.
        """
        import numbers  # local import: only needed for the numeric guard below

        api_in, api_out, num = info
        max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
                                          else CompareConst.MAX_ABS_ERR, dump_mode)
        in_diff = api_in[max_diff_index]
        out_diff = api_out[max_diff_index]
        # Placeholder cells (e.g. strings or None) would crash abs(); skip them as the
        # other checks do. numbers.Real also accepts NumPy scalar types.
        if not isinstance(in_diff, numbers.Real) or not isinstance(out_diff, numbers.Real):
            return
        if abs(in_diff) > abs(out_diff):
            return
        in_order = 0 if abs(in_diff) < 1 else math.log10(abs(in_diff))
        out_order = 0 if abs(out_diff) < 1 else math.log10(abs(out_diff))
        if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
            add_highlight_row_info(color_columns.yellow, num,
                                   "maximum absolute error of both input and output exceed 1, "
                                   "with the output larger by an order of magnitude")
|
|
30
59
|
|
|
31
60
|
|
|
32
61
|
class CheckOneThousandErrorRatio(HighlightCheck):
    """Check the one-thousandth error ratio between an op's input and output."""

    def apply(self, info, color_columns, dump_mode):
        """Highlight the row based on how the one-thousandth error ratio changes.

        ``info`` is unpacked as ``(api_in, api_out, num)``. Non-numeric cells are
        skipped. The row goes red when the input ratio is above
        ``ONE_THOUSAND_ERROR_IN_RED`` and the output ratio is below
        ``ONE_THOUSAND_ERROR_OUT_RED``. Otherwise it goes yellow when the ratio
        drops by more than ``ONE_THOUSAND_ERROR_DIFF_YELLOW``.
        """
        api_in, api_out, num = info
        ratio_idx = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
        in_ratio = api_in[ratio_idx]
        out_ratio = api_out[ratio_idx]
        both_numeric = isinstance(in_ratio, (float, int)) and isinstance(out_ratio, (float, int))
        if not both_numeric:
            return
        red_hit = (in_ratio > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
                   out_ratio < CompareConst.ONE_THOUSAND_ERROR_OUT_RED)
        if red_hit:
            add_highlight_row_info(color_columns.red, num,
                                   "The input's one thousandth err ratio exceeds 0.9, while the output's is below 0.6")
            return
        if in_ratio - out_ratio > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
            add_highlight_row_info(color_columns.yellow, num,
                                   "The output's one thousandth err ratio decreases by more than 0.1 "
                                   "compared to the input's")
|
|
43
77
|
|
|
44
78
|
|
|
45
79
|
class CheckCosineSimilarity(HighlightCheck):
|
|
46
80
|
"""检查余弦相似度"""
|
|
47
|
-
def apply(self, info, color_columns,
|
|
81
|
+
def apply(self, info, color_columns, dump_mode):
|
|
48
82
|
api_in, api_out, num = info
|
|
49
|
-
cosine_index = get_header_index(
|
|
83
|
+
cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
|
|
50
84
|
if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
|
|
51
85
|
return
|
|
52
86
|
if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
|
|
53
|
-
color_columns.yellow
|
|
87
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
88
|
+
"The output's cosine decreases by more than 0.1 compared to the input's")
|
|
54
89
|
|
|
55
90
|
|
|
56
91
|
class CheckMaxRelativeDiff(HighlightCheck):
|
|
57
92
|
"""检查最大相对差异"""
|
|
58
|
-
def apply(self, info, color_columns,
|
|
93
|
+
def apply(self, info, color_columns, dump_mode):
|
|
59
94
|
api_in, api_out, num = info
|
|
60
|
-
max_diff_index = get_header_index(
|
|
61
|
-
bench_max_index = get_header_index(
|
|
95
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
|
|
96
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
62
97
|
input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
|
|
63
98
|
output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
|
|
64
99
|
if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
|
|
65
100
|
(float, int)):
|
|
66
101
|
return
|
|
67
102
|
if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
|
|
68
|
-
color_columns.red.
|
|
69
|
-
elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
70
|
-
|
|
103
|
+
add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
|
|
104
|
+
elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
105
|
+
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
106
|
+
add_highlight_row_info(color_columns.yellow, num,
|
|
107
|
+
"The output's maximum relative error exceeds 0.1, while the input's is below 0.01")
|
|
71
108
|
|
|
72
109
|
|
|
73
110
|
class CheckOverflow(HighlightCheck):
|
|
74
111
|
"""检查是否存在溢出"""
|
|
75
|
-
def apply(self, info, color_columns,
|
|
112
|
+
def apply(self, info, color_columns, dump_mode):
|
|
76
113
|
line, num = info
|
|
77
|
-
npu_max_index = get_header_index(
|
|
78
|
-
npu_min_index = get_header_index(
|
|
79
|
-
max_diff_index = get_header_index(
|
|
114
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
115
|
+
npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
|
|
116
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
117
|
+
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
80
118
|
if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
|
|
81
119
|
line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
|
|
82
|
-
color_columns.red
|
|
120
|
+
add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
|
|
83
121
|
return
|
|
84
122
|
# check if Max_Diff > 1e+10
|
|
85
|
-
if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
|
|
86
|
-
color_columns.red
|
|
123
|
+
if isinstance(line[max_diff_index], (float, int)) and abs(line[max_diff_index]) > CompareConst.MAX_DIFF_RED:
|
|
124
|
+
add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
|
|
87
125
|
|
|
88
126
|
|
|
89
127
|
class HighlightRules:
|
|
@@ -105,13 +143,14 @@ class HighlightRules:
|
|
|
105
143
|
}
|
|
106
144
|
|
|
107
145
|
|
|
108
|
-
def find_error_rows(result, last_len, n_num_input, highlight_dict,
|
|
146
|
+
def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
|
|
109
147
|
"""找到单个API中需要高亮的行"""
|
|
110
|
-
if
|
|
148
|
+
if dump_mode == Const.MD5:
|
|
111
149
|
return
|
|
112
|
-
npu_max_index = get_header_index(
|
|
113
|
-
bench_max_index = get_header_index(
|
|
114
|
-
max_diff_index = get_header_index(
|
|
150
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
151
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
152
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
153
|
+
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
115
154
|
|
|
116
155
|
red_lines, yellow_lines = [], []
|
|
117
156
|
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
@@ -124,7 +163,7 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
|
|
|
124
163
|
num = last_len + i
|
|
125
164
|
line_info = LineInfo(line_data=line, num_pointer=num)
|
|
126
165
|
for rule in HighlightRules.basic_rules.values():
|
|
127
|
-
rule.apply(line_info, color_columns,
|
|
166
|
+
rule.apply(line_info, color_columns, dump_mode)
|
|
128
167
|
|
|
129
168
|
# 对API的输出与输入比较,进行误差判断
|
|
130
169
|
for n, api_out in enumerate(result[n_num_input:len(result)]):
|
|
@@ -142,36 +181,42 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
|
|
|
142
181
|
continue
|
|
143
182
|
|
|
144
183
|
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
|
|
145
|
-
if
|
|
184
|
+
if dump_mode == Const.SUMMARY:
|
|
146
185
|
for rule in HighlightRules.summary_compare_rules.values():
|
|
147
|
-
rule.apply(api_info, color_columns,
|
|
186
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
148
187
|
else:
|
|
149
188
|
for rule in HighlightRules.compare_rules.values():
|
|
150
|
-
rule.apply(api_info, color_columns,
|
|
189
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
151
190
|
|
|
152
|
-
|
|
153
|
-
|
|
191
|
+
red_lines_num_set = {x[0] for x in red_lines}
|
|
192
|
+
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
193
|
+
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
194
|
+
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
195
|
+
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
196
|
+
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
154
197
|
|
|
155
198
|
|
|
156
199
|
def get_name_and_state(name):
|
|
157
200
|
"""Get api/module name and state"""
|
|
158
|
-
if
|
|
159
|
-
api_name = name.split(
|
|
160
|
-
state =
|
|
201
|
+
if Const.INPUT in name:
|
|
202
|
+
api_name = name.split(Const.INPUT)[0]
|
|
203
|
+
state = Const.INPUT
|
|
161
204
|
else:
|
|
162
|
-
api_name = name.split(
|
|
163
|
-
state =
|
|
205
|
+
api_name = name.split(Const.OUTPUT)[0]
|
|
206
|
+
state = Const.OUTPUT
|
|
164
207
|
return api_name, state
|
|
165
208
|
|
|
166
209
|
|
|
167
|
-
def find_compare_result_error_rows(result_df, highlight_dict,
|
|
210
|
+
def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
|
|
168
211
|
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
169
212
|
result = result_df.values
|
|
170
213
|
start, input_num, output_num, end = 0, 0, 0, len(result_df)
|
|
171
214
|
last_api_name, last_state = None, None
|
|
172
215
|
num, last_len = 0, 0
|
|
216
|
+
progress_bar = tqdm(total=len(result), desc="API/Module Analyse Progress", unit="item", ncols=100)
|
|
173
217
|
for res_i in result:
|
|
174
|
-
|
|
218
|
+
api_full_name = safe_get_value(res_i, 0, "res_i")
|
|
219
|
+
api_name, state = get_name_and_state(api_full_name)
|
|
175
220
|
if last_api_name:
|
|
176
221
|
if api_name == last_api_name:
|
|
177
222
|
if state == last_state:
|
|
@@ -182,42 +227,102 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
|
|
|
182
227
|
else:
|
|
183
228
|
output_num = num
|
|
184
229
|
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
185
|
-
|
|
230
|
+
dump_mode)
|
|
186
231
|
num, last_api_name, last_state = 1, api_name, state
|
|
187
232
|
start += input_num + output_num
|
|
188
233
|
input_num, output_num = 1, 0
|
|
189
234
|
else:
|
|
190
235
|
num, last_api_name, last_state = 1, api_name, state
|
|
236
|
+
progress_bar.update(1)
|
|
237
|
+
progress_bar.close()
|
|
191
238
|
if state:
|
|
192
|
-
if state ==
|
|
239
|
+
if state == Const.INPUT:
|
|
193
240
|
input_num = num
|
|
194
241
|
else:
|
|
195
242
|
output_num = num
|
|
196
|
-
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
243
|
+
find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
|
|
244
|
+
dump_mode)
|
|
197
245
|
|
|
198
246
|
|
|
199
247
|
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
200
248
|
"""Write and highlight results in Excel"""
|
|
201
|
-
|
|
249
|
+
|
|
250
|
+
update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
202
251
|
|
|
203
252
|
wb = openpyxl.Workbook()
|
|
204
253
|
ws = wb.active
|
|
205
254
|
|
|
206
255
|
# write header
|
|
256
|
+
logger.info('Initializing Excel file.')
|
|
207
257
|
for j, col_name in enumerate(result_df.columns, start=1):
|
|
258
|
+
if not csv_value_is_valid(col_name):
|
|
259
|
+
raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
|
|
208
260
|
ws.cell(row=1, column=j, value=col_name)
|
|
209
261
|
|
|
210
262
|
for i, row in enumerate(result_df.iterrows(), start=2):
|
|
211
263
|
for j, value in enumerate(row[1], start=1):
|
|
212
|
-
if not isinstance(value, (float, int)):
|
|
264
|
+
if not isinstance(value, (float, int)) or isinstance(value, bool):
|
|
213
265
|
value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
|
|
266
|
+
if not csv_value_is_valid(value):
|
|
267
|
+
raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: "
|
|
268
|
+
f"{file_path}.")
|
|
214
269
|
ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
|
|
270
|
+
|
|
271
|
+
# 对可疑数据标色
|
|
272
|
+
logger.info('Coloring Excel in progress.')
|
|
273
|
+
col_len = len(result_df.columns)
|
|
274
|
+
red_fill = PatternFill(
|
|
275
|
+
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
276
|
+
)
|
|
277
|
+
yellow_fill = PatternFill(
|
|
278
|
+
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
279
|
+
)
|
|
280
|
+
for i in highlight_dict.get("red_rows", []):
|
|
281
|
+
for j in range(1, col_len + 1):
|
|
282
|
+
ws.cell(row=i + 2, column=j).fill = red_fill
|
|
283
|
+
for i in highlight_dict.get("yellow_rows", []):
|
|
284
|
+
for j in range(1, col_len + 1):
|
|
285
|
+
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
286
|
+
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
287
|
+
save_workbook(wb, file_path)
|
|
215
288
|
|
|
216
|
-
if (i - 2) in highlight_dict['red_rows']:
|
|
217
|
-
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
|
|
218
|
-
end_color=CompareConst.RED, fill_type="solid")
|
|
219
|
-
elif (i - 2) in highlight_dict['yellow_rows']:
|
|
220
|
-
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
|
|
221
|
-
end_color=CompareConst.YELLOW, fill_type="solid")
|
|
222
289
|
|
|
223
|
-
|
|
290
|
+
def update_highlight_err_msg(result_df, highlight_dict):
|
|
291
|
+
if result_df.shape[1] <= 1:
|
|
292
|
+
return
|
|
293
|
+
|
|
294
|
+
if CompareConst.NPU_MD5 in result_df.columns:
|
|
295
|
+
return
|
|
296
|
+
|
|
297
|
+
err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
|
|
298
|
+
red_lines_num_set = highlight_dict.get('red_rows')
|
|
299
|
+
|
|
300
|
+
for color in ['red', 'yellow']:
|
|
301
|
+
line_key = f'{color}_lines'
|
|
302
|
+
lines = highlight_dict.get(line_key, [])
|
|
303
|
+
for line_index, messages in lines:
|
|
304
|
+
if color == 'yellow' and line_index in red_lines_num_set:
|
|
305
|
+
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
306
|
+
|
|
307
|
+
for msg in messages:
|
|
308
|
+
if err_msg[line_index] == '':
|
|
309
|
+
err_msg[line_index] = msg
|
|
310
|
+
else:
|
|
311
|
+
err_msg[line_index] += '\n' + msg
|
|
312
|
+
|
|
313
|
+
if color == 'red':
|
|
314
|
+
red_lines_num_set.add(line_index)
|
|
315
|
+
|
|
316
|
+
result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def csv_value_is_valid(value: str) -> bool:
|
|
320
|
+
if not isinstance(value, str):
|
|
321
|
+
return True
|
|
322
|
+
try:
|
|
323
|
+
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
324
|
+
float(value)
|
|
325
|
+
except ValueError:
|
|
326
|
+
# otherwise, they will be considered as formular injections
|
|
327
|
+
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
328
|
+
return True
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from msprobe.core.compare.layer_mapping.layer_mapping import (
|
|
17
|
+
generate_data_mapping_by_layer_mapping,
|
|
18
|
+
generate_api_mapping_by_layer_mapping,
|
|
19
|
+
)
|