mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,36 +15,46 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
|
+
import math
|
|
19
|
+
import zlib
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
18
22
|
import numpy as np
|
|
19
|
-
|
|
20
|
-
from msprobe.core.common.
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import Const, CompareConst, FileCheckConst
|
|
25
|
+
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
|
|
21
26
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
22
27
|
|
|
23
28
|
|
|
24
29
|
def extract_json(dirname, stack_json=False):
    """Locate ``dump.json`` (or ``stack.json``) directly inside a dump directory.

    Args:
        dirname (str): directory to search (only its immediate entries are listed).
        stack_json (bool): when True look for ``stack.json`` instead of ``dump.json``.

    Returns:
        str: full path of the requested file; '' when ``stack.json`` was requested
        but is not present.

    Raises:
        CompareException: NO_DUMP_FILE_ERROR when ``dump.json`` is not present.
    """
    wanted = 'stack.json' if stack_json else 'dump.json'
    json_path = os.path.join(dirname, wanted) if wanted in os.listdir(dirname) else ''

    # Provide robustness on invalid directory inputs
    if json_path:
        return json_path
    if stack_json:
        # A missing stack file is tolerated: warn and hand back the empty path.
        logger.warning(f'stack.json is not found in dump dir {dirname}.')
        return json_path
    logger.error(f'dump.json is not found in dump dir {dirname}.')
    raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
|
|
42
45
|
|
|
43
46
|
|
|
47
|
+
def set_stack_json_path(input_param):
    """Record the location of ``stack.json`` that sits next to the NPU dump file.

    Looks in the directory of ``input_param["npu_json_path"]``, stores the found
    path (or None) under ``input_param["stack_json_path"]``, and returns whether
    a stack file was found.
    """
    dump_dir = os.path.dirname(input_param.get("npu_json_path"))
    found_path = extract_json(dump_dir, stack_json=True)
    input_param["stack_json_path"] = found_path or None
    return bool(found_path)
|
|
52
|
+
|
|
53
|
+
|
|
44
54
|
def check_and_return_dir_contents(dump_dir, prefix):
|
|
45
55
|
"""
|
|
46
56
|
check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
|
|
47
|
-
pattern: ^{prefix}(?:0|[
|
|
57
|
+
pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
|
|
48
58
|
|
|
49
59
|
Args:
|
|
50
60
|
dump_dir (str): dump dir
|
|
@@ -60,7 +70,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
60
70
|
check_regex_prefix_format_valid(prefix)
|
|
61
71
|
check_file_or_directory_path(dump_dir, True)
|
|
62
72
|
contents = os.listdir(dump_dir)
|
|
63
|
-
pattern = re.compile(rf'^{prefix}(?:0|[
|
|
73
|
+
pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
|
|
64
74
|
for name in contents:
|
|
65
75
|
if not pattern.match(name):
|
|
66
76
|
logger.error(
|
|
@@ -72,6 +82,10 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
72
82
|
|
|
73
83
|
|
|
74
84
|
def rename_api(npu_name, process):
|
|
85
|
+
"""
|
|
86
|
+
原api: {api_type}.{api_name}.{API调用次数}.{前向反向}.{input/output}.{参数序号}
|
|
87
|
+
rename后: {api_type}.{api_name}.{input/output}.{参数序号}
|
|
88
|
+
"""
|
|
75
89
|
npu_split = npu_name.split(process)
|
|
76
90
|
try:
|
|
77
91
|
torch_func_index, in_out = npu_split[0], npu_split[1]
|
|
@@ -84,122 +98,89 @@ def rename_api(npu_name, process):
|
|
|
84
98
|
|
|
85
99
|
|
|
86
100
|
def read_op(op_data, op_name):
    """Flatten one api/module entry of the dump data into a list of leaf items.

    Entries whose dotted name contains ``Const.PARAMS_GRAD`` are parsed as a whole.
    Any other entry is parsed section by section for each key of
    ``CompareConst.IO_NAME_MAPPING`` present in ``op_data``, with that key's mapped
    suffix appended to ``op_name``.
    """
    if Const.PARAMS_GRAD in op_name.split(Const.SEP):
        return op_item_parse(op_data, op_name)

    parsed_items = []
    for io_key, name_suffix in CompareConst.IO_NAME_MAPPING.items():
        if io_key not in op_data:
            continue
        parsed_items.extend(op_item_parse(op_data[io_key], op_name + name_suffix))
    return parsed_items
|
|
122
109
|
|
|
123
110
|
|
|
124
|
-
def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
    """Recursively flatten dumped api/module data into a flat list of leaf items.

    Args:
        op_data: JSON-loaded data for one api/module section (list, dict, None, ...).
        op_name (str): dotted name prefix for the items produced from ``op_data``.
        depth (int): current recursion depth.

    Returns:
        list: leaf items built by ``gen_op_item``. ``None`` yields a single
        placeholder item whose statistics are all None; other falsy values and
        types that are neither list nor dict yield ``[]``.

    Raises:
        CompareException: RECURSION_LIMIT_ERROR when nesting exceeds ``Const.MAX_DEPTH``.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
        raise CompareException(CompareException.RECURSION_LIMIT_ERROR)

    if op_data is None:
        # Built only on this path; previously it was allocated on every recursive call.
        return [{
            'full_op_name': op_name,
            'type': None,
            'Max': None,
            'Min': None,
            'Mean': None,
            'Norm': None,
            'dtype': None,
            'shape': None,
            'md5': None,
            'value': None,
            'data_name': '-1'
        }]
    if not op_data:
        return []

    item_list = []
    if isinstance(op_data, list):
        # Loop-invariant: parameters_grad elements keep the parent name,
        # all other list elements get an ".{index}" suffix.
        is_params_grad = Const.PARAMS_GRAD in op_name.split(Const.SEP)
        for i, data in enumerate(op_data):
            sub_name = op_name if is_params_grad else op_name + Const.SEP + str(i)
            item_list.extend(op_item_parse(data, sub_name, depth + 1))
    elif isinstance(op_data, dict):
        if is_leaf_data(op_data):
            return [gen_op_item(op_data, op_name)]
        for sub_name, sub_data in op_data.items():
            item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
    return item_list
|
|
201
148
|
|
|
202
149
|
|
|
150
|
+
def is_leaf_data(op_data):
    """Return True when ``op_data`` is a leaf record, i.e. it has a string ``'type'`` field."""
    if 'type' not in op_data:
        return False
    return isinstance(op_data['type'], str)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def gen_op_item(op_data, op_name):
    """Build a normalized comparison item from one leaf dump record.

    Returns a shallow copy of ``op_data`` in which:
      * ``data_name`` is '-1' when missing or empty;
      * ``full_op_name`` is ``data_name`` without its last dotted suffix, or
        ``op_name`` when there is no data name;
      * ``Max``/``Min``/``Mean``/``Norm`` default to None;
      * when ``dtype`` is falsy, ``dtype``/``shape`` are derived from ``type`` and
        ``value`` ('torch.Size', 'slice', or anything else, where the value also
        fills every statistic);
      * ``md5`` falls back to the CRC32 hex digest of ``str(value)``.
    """
    stat_keys = ['Max', 'Min', 'Mean', 'Norm']
    op_item = dict(op_data)

    # An empty-string data_name is treated the same as a missing one: '-1'.
    data_name = op_data.get('data_name') or '-1'
    op_item['data_name'] = data_name
    if data_name == '-1':
        op_item['full_op_name'] = op_name
    else:
        op_item['full_op_name'] = data_name.rsplit(Const.SEP, 1)[0]

    for stat in stat_keys:
        op_item.setdefault(stat, None)

    if not op_item.get('dtype'):
        item_type = op_item.get('type')
        raw_value = op_data.get('value')
        if item_type == 'torch.Size':
            op_item['dtype'] = op_data.get('type')
            op_item['shape'] = str(raw_value)
        elif item_type == 'slice':
            op_item['dtype'] = op_data.get('type')
            op_item['shape'] = str(np.shape(np.array(raw_value)))
        else:
            # Other records without a dtype (presumably plain Python values):
            # the value itself stands in for every statistic.
            op_item['dtype'] = str(type(raw_value))
            op_item['shape'] = '[]'
            for stat in stat_keys:
                op_item[stat] = raw_value

    if not op_item.get('md5'):
        op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"

    return op_item
|
|
182
|
+
|
|
183
|
+
|
|
203
184
|
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
204
185
|
"""
|
|
205
186
|
Function Description:
|
|
@@ -231,223 +212,387 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
|
231
212
|
item_list.append(parsed_item)
|
|
232
213
|
|
|
233
214
|
|
|
234
|
-
def process_summary_data(summary_data):
    """Return a new list with every float NaN in ``summary_data`` replaced by ``CompareConst.NAN``."""
    processed = []
    for value in summary_data:
        is_float_nan = isinstance(value, float) and math.isnan(value)
        processed.append(CompareConst.NAN if is_float_nan else value)
    return processed
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
    """Fill the summary-mode diff and relative-diff columns of a comparison row.

    ``start_idx`` is the position of ``CompareConst.MAX_DIFF`` in
    ``CompareConst.SUMMARY_COMPARE_RESULT_HEADER``. For the i-th statistic pair,
    the absolute difference goes to ``start_idx + i`` and the relative difference
    (a percentage string) goes to
    ``start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM``. Pairs that are not
    both real numbers get ``CompareConst.N_A`` in both columns.

    Args:
        result_item (list): row being built; modified in place and returned.
        npu_summary_data (list): NPU-side statistics.
        bench_summary_data (list): bench-side statistics.
        err_msg (str): accumulated error message for the row.

    Returns:
        tuple: ``(result_item, accuracy_check, err_msg)``. ``accuracy_check`` is
        ``CompareConst.WARNING`` when some statistic differs by more than
        ``CompareConst.MAGNITUDE`` relative to the larger magnitude, or by an
        infinite amount. Otherwise it is "".
    """
    start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
    warning_flag = False
    for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
        # bool is a subclass of int but is not a meaningful statistic value.
        if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
            diff = npu_val - bench_val
            if math.isnan(diff):
                diff = CompareConst.NAN
                relative = CompareConst.NAN
            else:
                if bench_val != 0:
                    relative = str(abs((diff / bench_val) * 100)) + '%'
                else:
                    relative = CompareConst.N_A
                magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
                # An infinite diff (one side inf, or inf vs -inf) makes magnitude_diff
                # inf / inf = nan. nan never compares greater than the threshold,
                # so flag the infinite case explicitly.
                if math.isinf(diff) or magnitude_diff > CompareConst.MAGNITUDE:
                    warning_flag = True
            result_item[start_idx + i] = diff
            result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
        else:
            result_item[start_idx + i] = CompareConst.N_A
            result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A

    accuracy_check = CompareConst.WARNING if warning_flag else ""
    err_msg += "Need double check api accuracy." if warning_flag else ""
    # Append a tab to inf/nan cells. Presumably this keeps spreadsheet tools from
    # reinterpreting them (NOTE(review): confirm intent).
    for i in range(start_idx, len(result_item)):
        if str(result_item[i]) in ('inf', '-inf', 'nan'):
            result_item[i] = f'{result_item[i]}\t'
    return result_item, accuracy_check, err_msg
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@dataclass
class ApiItemInfo:
    """Information about one api item on one side (NPU or bench) of a comparison."""

    # Api/op name of the item.
    name: str
    # Tuple of per-item structural fields (exact layout defined by callers — TODO confirm).
    struct: tuple
    # Stack information associated with the item.
    stack_info: list
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
    """Append the stack column(s) to ``result_item`` and return it.

    The NPU stack info is spliced in only when stack info is available and the row
    is the first input entry (``index == 0`` and ``key == CompareConst.INPUT_STRUCT``).
    Every other row gets a single ``CompareConst.NONE`` placeholder.
    """
    show_stack = has_stack and index == 0 and key == CompareConst.INPUT_STRUCT
    if not show_stack:
        result_item.append(CompareConst.NONE)
        return result_item
    result_item.extend(npu_stack_info)
    return result_item
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def result_item_init(n_info, b_info, dump_mode):
    """Build the leading columns of one comparison row for a matched NPU/bench pair.

    The row starts with NPU/bench name, dtype (struct[0]) and shape (struct[1]).
    In MD5 mode it continues with both MD5 values (struct[2]) and a PASS/DIFF verdict.
    In summary mode eight blank placeholders follow; otherwise five follow.

    Raises:
        CompareException: if either struct has fewer entries than needed
            (3 in MD5 mode, 2 otherwise).
    """
    required_len = 3 if dump_mode == Const.MD5 else 2
    npu_len, bench_len = len(n_info.struct), len(b_info.struct)
    if min(npu_len, bench_len) < required_len:
        err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
                  f"npu_info_struct is {n_info.struct}\n" \
                  f"bench_info_struct is {b_info.struct}"
        logger.error(err_msg)
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)

    npu_struct, bench_struct = n_info.struct, b_info.struct
    result_item = [
        n_info.name, b_info.name,
        npu_struct[0], bench_struct[0],
        npu_struct[1], bench_struct[1],
    ]
    if dump_mode == Const.MD5:
        npu_md5, bench_md5 = npu_struct[2], bench_struct[2]
        verdict = CompareConst.PASS if npu_md5 == bench_md5 else CompareConst.DIFF
        result_item += [npu_md5, bench_md5, verdict]
    else:
        placeholder_count = 8 if dump_mode == Const.SUMMARY else 5
        result_item += [" "] * placeholder_count
    return result_item
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def count_struct(op_dict):
    """Return the number of op names and of entries in each struct list of ``op_dict``.

    The result is a 5-tuple of lengths in the order
    ``(op_name, input_struct, output_struct, params_struct, params_grad_struct)``.
    A missing key counts as 0.

    Raises:
        CompareException: if the op-name count differs from the total struct count.
    """
    struct_keys = (
        CompareConst.INPUT_STRUCT,
        CompareConst.OUTPUT_STRUCT,
        CompareConst.PARAMS_STRUCT,
        CompareConst.PARAMS_GRAD_STRUCT,
    )
    name_count = len(op_dict.get(CompareConst.OP_NAME, []))
    struct_counts = tuple(len(op_dict.get(struct_key, [])) for struct_key in struct_keys)
    if name_count != sum(struct_counts):
        logger.error(f"Length of names and structs of op_dict not match. Please check! op_dict: {op_dict}")
        raise CompareException(CompareException.NAMES_STRUCTS_MATCH_ERROR)
    return (name_count,) + struct_counts
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def get_accuracy(result, n_dict, b_dict, dump_mode):
    """Append comparison rows for one matched NPU/bench op pair to ``result``.

    Rows are produced per struct group, in the order input, parameters, output,
    parameters_grad. The ``n_start``/``b_start`` offsets below assume the op-name and
    summary lists are laid out as inputs, outputs, parameters, parameters_grad.
    ``dump_mode`` is one of ``Const.ALL``, ``Const.SUMMARY`` or ``Const.MD5``.

    Raises:
        CompareException: when struct/name counts disagree (via ``count_struct``) or
            when an index is out of bounds while building NPU-only rows.
    """

    def get_accuracy_core(n_start, n_len, b_start, b_len, key):
        # n_start/b_start: offset of this struct group inside the op_name/summary lists.
        # n_len/b_len: number of entries in this struct group on each side.
        min_len = min(n_len, b_len)
        npu_stack_info = n_dict.get("stack_info", None)
        bench_stack_info = b_dict.get("stack_info", None)
        has_stack = npu_stack_info and bench_stack_info

        if dump_mode == Const.ALL:
            npu_data_name = n_dict.get("data_name", None)

        # Entries present on both the NPU and the bench side.
        for index in range(min_len):
            n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
            b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
            n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
            b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
            err_msg = ""

            npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
            bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
            result_item = result_item_init(npu_info, bench_info, dump_mode)

            if dump_mode == Const.MD5:
                result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
                result.append(result_item)
                continue

            npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
            bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
            result_item.extend(process_summary_data(npu_summary_data))
            result_item.extend(process_summary_data(bench_summary_data))

            if dump_mode == Const.SUMMARY:
                result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
                                                                                  bench_summary_data, err_msg)
                result_item.append(accuracy_check)
            else:
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
            result_item.append(err_msg)
            result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
            if dump_mode == Const.ALL:
                result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))

            result.append(result_item)

        # NPU-only entries: the NPU side has more entries in this group than the bench side.
        if n_len > b_len:
            # Use the same placeholder width as result_item_init / get_un_match_accuracy
            # (8 diff columns in summary mode, 5 otherwise) so the columns line up.
            placeholder_count = 8 if dump_mode == Const.SUMMARY else 5
            for index in range(b_len, n_len):
                try:
                    n_name = n_dict['op_name'][n_start + index]
                    n_struct = n_dict[key][index]
                    if dump_mode == Const.MD5:
                        result_item = [
                            n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
                            n_struct[2], CompareConst.NAN, CompareConst.NAN
                        ]
                        # Every other row type ends with a stack column; keep MD5 rows aligned too.
                        result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
                        result.append(result_item)
                        continue
                    result_item = [
                        n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN
                    ]
                    result_item.extend([" "] * placeholder_count)
                    summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
                    result_item.extend(summary_data)
                    summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
                    result_item.extend(summary_data)
                except IndexError as e:
                    err_msg = "index out of bounds error occurs, please check!\n" \
                              f"n_dict is {n_dict}"
                    logger.error(err_msg)
                    raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e

                err_msg = ""
                result_item.append(CompareConst.ACCURACY_CHECK_YES)
                result_item.append(err_msg)
                result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
                if dump_mode == Const.ALL:
                    result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))

                result.append(result_item)

    # The first element (total op-name count) is only validated by count_struct.
    _, n_num_input, n_num_output, n_num_params, n_num_params_grad = count_struct(n_dict)
    _, b_num_input, b_num_output, b_num_params, b_num_params_grad = count_struct(b_dict)

    get_accuracy_core(0, n_num_input, 0, b_num_input, CompareConst.INPUT_STRUCT)
    get_accuracy_core(n_num_input + n_num_output, n_num_params, b_num_input + b_num_output, b_num_params,
                      CompareConst.PARAMS_STRUCT)
    get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, CompareConst.OUTPUT_STRUCT)
    get_accuracy_core(n_num_input + n_num_output + n_num_params, n_num_params_grad,
                      b_num_input + b_num_output + b_num_params, b_num_params_grad,
                      CompareConst.PARAMS_GRAD_STRUCT)
|
|
356
395
|
|
|
357
396
|
|
|
358
|
-
def
|
|
359
|
-
|
|
397
|
+
def append_stack_info(result_item, npu_stack_info, index):
    """Append the NPU stack info to ``result_item`` in place (returns None).

    The stack info list is spliced in only for the first row (``index == 0``) and
    only when it is truthy. Otherwise a single ``CompareConst.NONE`` is appended.
    """
    if not npu_stack_info or index != 0:
        result_item.append(CompareConst.NONE)
        return
    result_item.extend(npu_stack_info)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def get_un_match_accuracy(result, n_dict, dump_mode):
    """Append rows for an NPU op that has no bench counterpart to ``result``.

    Bench name/dtype/shape and all bench-side columns are ``CompareConst.N_A``.
    The accuracy column is ``N_A`` and the error column is ``CompareConst.NO_BENCH``.
    Entries are visited in the order returned by ``reorder_op_x_list``.

    Raises:
        CompareException: if a struct entry lacks dtype/shape fields.
    """
    npu_stack_info = n_dict.get("stack_info", None)
    bench_name = bench_type = bench_shape = CompareConst.N_A

    # Next read position inside each struct list of n_dict.
    next_struct_index = dict.fromkeys(
        (CompareConst.INPUT_STRUCT, CompareConst.OUTPUT_STRUCT,
         CompareConst.PARAMS_STRUCT, CompareConst.PARAMS_GRAD_STRUCT),
        0,
    )

    ordered_names, ordered_summaries, _ = reorder_op_x_list(n_dict.get(CompareConst.OP_NAME),
                                                            n_dict.get(Const.SUMMARY),
                                                            n_dict.get('data_name'))
    for position, npu_name in enumerate(ordered_names):
        struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(get_name_and_state(npu_name)[1])
        if not struct_key:
            continue
        npu_struct = safe_get_value(n_dict, next_struct_index.get(struct_key), "n_dict", key=struct_key)
        next_struct_index[struct_key] += 1

        try:
            result_item = [npu_name, bench_name, npu_struct[0], bench_type, npu_struct[1], bench_shape]
        except IndexError as e:
            err_msg = "index out of bounds error occurs, please check!\n" \
                      f"op_name of n_dict is {n_dict['op_name']}\n" \
                      f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
                      f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
            logger.error(err_msg)
            raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e

        if dump_mode == Const.MD5:
            result_item.extend([CompareConst.N_A] * 3)
            append_stack_info(result_item, npu_stack_info, position)
            result.append(result_item)
            continue

        # Placeholder diff columns: 8 in summary mode, 5 in full-data mode.
        if dump_mode == Const.SUMMARY:
            result_item.extend([CompareConst.N_A] * 8)
        if dump_mode == Const.ALL:
            result_item.extend([CompareConst.N_A] * 5)

        result_item.extend(safe_get_value(ordered_summaries, position, "summary_reorder"))
        result_item.extend([CompareConst.N_A] * 4)
        result_item.append(CompareConst.N_A)
        result_item.append(CompareConst.NO_BENCH)
        append_stack_info(result_item, npu_stack_info, position)
        if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
            result_item.append("-1")
        result.append(result_item)
|
|
398
462
|
|
|
399
463
|
|
|
400
|
-
def merge_tensor(tensor_list,
|
|
464
|
+
def merge_tensor(tensor_list, dump_mode):
    """Collapse the per-tensor records of one op into a single op dict.

    The dict holds parallel lists: ``op_name``, one struct list per state
    (dtype/shape tuples, plus md5 in MD5 mode), ``summary`` (max/min/mean/norm),
    ``stack_info``, and ``data_name`` (only when ``dump_mode == Const.ALL``).
    The kwargs struct list is dropped when empty. An empty dict is returned when no
    op name was collected.
    """
    struct_keys = (
        CompareConst.INPUT_STRUCT,
        CompareConst.KWARGS_STRUCT,
        CompareConst.OUTPUT_STRUCT,
        CompareConst.PARAMS_STRUCT,
        CompareConst.PARAMS_GRAD_STRUCT,
    )
    op_dict = {"op_name": []}
    for struct_key in struct_keys:
        op_dict[struct_key] = []
    op_dict[Const.SUMMARY] = []
    op_dict["stack_info"] = []
    if dump_mode == Const.ALL:
        op_dict["data_name"] = []

    is_md5 = dump_mode == Const.MD5
    for tensor in tensor_list:
        # A two-key record ('full_op_name', 'full_info') carries the stack info; per the
        # original note it is only added in stack mode, and it ends the scan.
        if len(tensor) == 2:
            op_dict['stack_info'].append(tensor['full_info'])
            break

        full_name = tensor['full_op_name']
        op_dict["op_name"].append(full_name)

        struct_key = CompareConst.STATE_TO_STRUCT_MAPPING.get(get_name_and_state(full_name)[1])
        if not struct_key:
            continue
        append_struct = op_dict.get(struct_key).append
        if is_md5:
            append_struct((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
        else:
            append_struct((tensor[Const.DTYPE], tensor[Const.SHAPE]))
        op_dict[Const.SUMMARY].append(
            [tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]
        )

        if dump_mode == Const.ALL:
            op_dict["data_name"].append(tensor['data_name'])

    if not op_dict[CompareConst.KWARGS_STRUCT]:
        del op_dict[CompareConst.KWARGS_STRUCT]
    return op_dict if op_dict["op_name"] else {}
|
|
444
502
|
|
|
445
503
|
|
|
504
|
+
def print_compare_ends_info():
    """Log a three-line asterisk banner announcing that the comparison finished."""
    width = len(CompareConst.COMPARE_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
    border = '*' * width
    message_line = "*" + CompareConst.COMPARE_ENDS_SUCCESSFULLY.center(width - 2) + "*"
    for line in (border, message_line, border):
        logger.info(line)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def table_value_is_valid(value: str) -> bool:
    """Return whether ``value`` is safe to write into a result table cell.

    Non-string values are always accepted. Strings that parse as floats (e.g.
    ``"-1.00"`` or ``"+1.00"``) count as numbers and are accepted. Any other string is
    rejected when it matches ``FileCheckConst.CSV_BLACK_LIST``, which guards against
    formula injection.
    """
    if not isinstance(value, str):
        return True
    try:
        float(value)
    except ValueError:
        return re.search(FileCheckConst.CSV_BLACK_LIST, value) is None
    return True
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
def get_name_and_state(name):
    """
    Get api/module name and state

    example:
        name = 'conv2d.forward.1.input.0'
        return: ('conv2d.forward.1.', 'input')

        name = 'Functional.pad.0.backward.output.0'
        return: ('Functional.pad.0.backward.', 'output')

    state type: input, output, kwargs, parameters, parameters_grad

    Raises:
        CompareException: if ``name`` has no forward/backward marker, or the segment
            after it is not '<optional index.>(input|output|kwargs|parameters).<...>'.
    """
    if Const.PARAMS_GRAD in name.split(Const.SEP):
        return name.split(Const.PARAMS_GRAD)[0], Const.PARAMS_GRAD

    split = re.split(Const.REGEX_FORWARD_BACKWARD, name)
    # Without a forward/backward marker re.split yields fewer than three parts;
    # report that as an invalid name rather than letting split[1]/split[2] raise IndexError.
    if len(split) < 3:
        raise CompareException(f'Invalid name string: {name}')
    api = f'{split[0]}.{split[1]}.'
    state_str = split[2]
    match = re.match(r'^(\d+\.)?(input|output|kwargs|parameters)\..+$', state_str)
    if not match:
        raise CompareException(f'Invalid name string: {name}')
    if match.group(1):
        # Keep the call index (e.g. '1.') as part of the api prefix.
        api = f'{api}{match.group(1)}'
    state = match.group(2)
    return api, state
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def reorder_op_name_list(op_name_list):
    """Regroup op names as: others, parameters, output, parameters_grad.

    The relative order inside each group is preserved, so parameters end up before
    outputs. An empty or None input is returned as is.
    """
    if not op_name_list:
        return op_name_list

    buckets = {Const.PARAMS: [], Const.OUTPUT: [], Const.PARAMS_GRAD: []}
    others = []
    for op_name in op_name_list:
        buckets.get(get_name_and_state(op_name)[1], others).append(op_name)
    return others + buckets[Const.PARAMS] + buckets[Const.OUTPUT] + buckets[Const.PARAMS_GRAD]
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def reorder_op_x_list(op_name_list, summary_list, data_name_list):
    """Reorder op names, summaries and data names consistently.

    Parameters are moved after inputs and before outputs, using the order produced by
    ``reorder_op_name_list``. ``data_name_list`` may be None/empty, since it is None
    for statistics comparison. In that case it is returned unchanged. If
    ``op_name_list`` or ``summary_list`` is empty, all three inputs are returned
    untouched.
    """
    if not op_name_list or not summary_list:
        return op_name_list, summary_list, data_name_list

    position_of = {op_name: pos for pos, op_name in enumerate(op_name_list)}
    ordered_names = reorder_op_name_list(op_name_list)
    ordered_positions = [position_of.get(op_name) for op_name in ordered_names]

    ordered_summaries = [summary_list[pos] for pos in ordered_positions]
    if data_name_list:
        ordered_data_names = [data_name_list[pos] for pos in ordered_positions]
    else:
        ordered_data_names = data_name_list
    return ordered_names, ordered_summaries, ordered_data_names
|
|
588
|
+
|
|
589
|
+
|
|
446
590
|
def _compare_parser(parser):
|
|
447
591
|
parser.add_argument("-i", "--input_path", dest="input_path", type=str,
|
|
448
592
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
449
593
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
450
|
-
help="<Required> The compare task result out path.",
|
|
594
|
+
help="<Required> The compare task result out path. Default path: ./output",
|
|
595
|
+
required=False, default="./output", nargs="?", const="./output")
|
|
451
596
|
parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
|
|
452
597
|
help="<optional> Whether to save stack info.", required=False)
|
|
453
598
|
parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
|
|
@@ -457,8 +602,8 @@ def _compare_parser(parser):
|
|
|
457
602
|
parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True,
|
|
458
603
|
help="<optional> The cell mapping file path.", required=False)
|
|
459
604
|
parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
|
|
460
|
-
help="<optional> The api mapping file path.", required=False)
|
|
605
|
+
help="<optional> The api mapping file path.", required=False)
|
|
461
606
|
parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
|
|
462
607
|
help="<optional> The data mapping file path.", required=False)
|
|
463
|
-
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
|
|
608
|
+
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
464
609
|
help="<optional> The layer mapping file path.", required=False)
|