mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/core/compare/utils.py
CHANGED
|
@@ -1,28 +1,45 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
1
15
|
|
|
2
16
|
import os
|
|
3
17
|
import re
|
|
18
|
+
import math
|
|
19
|
+
import zlib
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
4
22
|
import numpy as np
|
|
23
|
+
|
|
5
24
|
from msprobe.core.common.const import Const, CompareConst
|
|
6
|
-
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger
|
|
25
|
+
from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger, safe_get_value
|
|
7
26
|
from msprobe.core.common.file_utils import check_file_or_directory_path
|
|
8
27
|
|
|
9
28
|
|
|
10
29
|
def extract_json(dirname, stack_json=False):
|
|
11
30
|
json_path = ''
|
|
12
|
-
for
|
|
13
|
-
if
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
json_path = full_path
|
|
18
|
-
if not stack_json and 'stack' not in json_path:
|
|
19
|
-
break
|
|
20
|
-
if stack_json and 'stack' in json_path:
|
|
21
|
-
break
|
|
31
|
+
for filename in os.listdir(dirname):
|
|
32
|
+
target_file_name = 'stack.json' if stack_json else 'dump.json'
|
|
33
|
+
if filename == target_file_name:
|
|
34
|
+
json_path = os.path.join(dirname, filename)
|
|
35
|
+
break
|
|
22
36
|
|
|
23
37
|
# Provide robustness on invalid directory inputs
|
|
24
38
|
if not json_path:
|
|
25
|
-
|
|
39
|
+
if stack_json:
|
|
40
|
+
logger.error(f'stack.json is not found in dump dir {dirname}.')
|
|
41
|
+
else:
|
|
42
|
+
logger.error(f'dump.json is not found in dump dir {dirname}.')
|
|
26
43
|
raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
|
|
27
44
|
return json_path
|
|
28
45
|
|
|
@@ -30,7 +47,7 @@ def extract_json(dirname, stack_json=False):
|
|
|
30
47
|
def check_and_return_dir_contents(dump_dir, prefix):
|
|
31
48
|
"""
|
|
32
49
|
check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
|
|
33
|
-
pattern: ^{prefix}(?:0|[
|
|
50
|
+
pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
|
|
34
51
|
|
|
35
52
|
Args:
|
|
36
53
|
dump_dir (str): dump dir
|
|
@@ -46,7 +63,7 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
46
63
|
check_regex_prefix_format_valid(prefix)
|
|
47
64
|
check_file_or_directory_path(dump_dir, True)
|
|
48
65
|
contents = os.listdir(dump_dir)
|
|
49
|
-
pattern = re.compile(rf'^{prefix}(?:0|[
|
|
66
|
+
pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
|
|
50
67
|
for name in contents:
|
|
51
68
|
if not pattern.match(name):
|
|
52
69
|
logger.error(
|
|
@@ -59,122 +76,100 @@ def check_and_return_dir_contents(dump_dir, prefix):
|
|
|
59
76
|
|
|
60
77
|
def rename_api(npu_name, process):
|
|
61
78
|
npu_split = npu_name.split(process)
|
|
62
|
-
|
|
79
|
+
try:
|
|
80
|
+
torch_func_index, in_out = npu_split[0], npu_split[1]
|
|
81
|
+
except IndexError as error:
|
|
82
|
+
logger.error(f'{npu_name} can not be split with {process}, please check!')
|
|
83
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
63
84
|
torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
|
|
64
85
|
torch_func = str(torch_func_split[0]) + str(in_out)
|
|
65
86
|
return torch_func
|
|
66
87
|
|
|
67
88
|
|
|
68
89
|
def read_op(op_data, op_name):
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
op_parsed_list += kwarg_parsed_list
|
|
81
|
-
kwarg_parsed_list.clear()
|
|
82
|
-
elif kwargs_item:
|
|
83
|
-
for kwarg in kwargs_item:
|
|
84
|
-
kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.' + kwarg, None)
|
|
85
|
-
op_parsed_list += kwarg_parsed_list
|
|
86
|
-
kwarg_parsed_list.clear()
|
|
87
|
-
if Const.OUTPUT in op_data:
|
|
88
|
-
output_item = op_data[Const.OUTPUT]
|
|
89
|
-
output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
|
|
90
|
-
op_parsed_list += output_parsed_list
|
|
91
|
-
output_parsed_list.clear()
|
|
92
|
-
if Const.BACKWARD in op_name:
|
|
93
|
-
if Const.INPUT in op_data:
|
|
94
|
-
input_item = op_data[Const.INPUT]
|
|
95
|
-
input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
|
|
96
|
-
op_parsed_list = input_parsed_list.copy()
|
|
97
|
-
input_parsed_list.clear()
|
|
98
|
-
if Const.OUTPUT in op_data:
|
|
99
|
-
output_item = op_data[Const.OUTPUT]
|
|
100
|
-
output_parsed_list = op_item_parse(output_item, op_name + '.output', None)
|
|
101
|
-
op_parsed_list += output_parsed_list
|
|
102
|
-
output_parsed_list.clear()
|
|
90
|
+
io_name_mapping = {
|
|
91
|
+
Const.INPUT_ARGS: '.input',
|
|
92
|
+
Const.INPUT_KWARGS: '.input',
|
|
93
|
+
Const.INPUT: '.input',
|
|
94
|
+
Const.OUTPUT: '.output'
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
op_parsed_list = []
|
|
98
|
+
for name in io_name_mapping:
|
|
99
|
+
if name in op_data:
|
|
100
|
+
op_parsed_list.extend(op_item_parse(op_data[name], op_name + io_name_mapping[name]))
|
|
103
101
|
return op_parsed_list
|
|
104
102
|
|
|
105
103
|
|
|
106
|
-
def op_item_parse(
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
if
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
parsed_item['dtype'] = 'torch.Size'
|
|
140
|
-
parsed_item['shape'] = str(item['value'])
|
|
141
|
-
parsed_item['md5'] = None
|
|
142
|
-
parsed_item['Max'] = None
|
|
143
|
-
parsed_item['Min'] = None
|
|
144
|
-
parsed_item['Mean'] = None
|
|
145
|
-
parsed_item['Norm'] = None
|
|
146
|
-
parsed_item['data_name'] = '-1'
|
|
147
|
-
item_list.append(parsed_item)
|
|
148
|
-
elif item['type'] == 'slice':
|
|
149
|
-
parsed_item['full_op_name'] = full_op_name
|
|
150
|
-
parsed_item['dtype'] = 'slice'
|
|
151
|
-
parsed_item['shape'] = str(np.shape(np.array(item['value'])))
|
|
152
|
-
parsed_item['md5'] = None
|
|
153
|
-
parsed_item['Max'] = None
|
|
154
|
-
parsed_item['Min'] = None
|
|
155
|
-
parsed_item['Mean'] = None
|
|
156
|
-
parsed_item['Norm'] = None
|
|
157
|
-
parsed_item['data_name'] = '-1'
|
|
158
|
-
item_list.append(parsed_item)
|
|
159
|
-
else:
|
|
160
|
-
parsed_item['full_op_name'] = full_op_name
|
|
161
|
-
parsed_item['dtype'] = str(type(item['value']))
|
|
162
|
-
parsed_item['shape'] = '[]'
|
|
163
|
-
parsed_item['md5'] = None
|
|
164
|
-
parsed_item['Max'] = item['value']
|
|
165
|
-
parsed_item['Min'] = item['value']
|
|
166
|
-
parsed_item['Mean'] = item['value']
|
|
167
|
-
parsed_item['Norm'] = item['value']
|
|
168
|
-
parsed_item['data_name'] = '-1'
|
|
169
|
-
item_list.append(parsed_item)
|
|
170
|
-
else:
|
|
171
|
-
resolve_api_special_parameters(item, full_op_name, item_list)
|
|
172
|
-
else:
|
|
173
|
-
for j, item_spec in enumerate(item):
|
|
174
|
-
op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
|
|
104
|
+
def op_item_parse(op_data, op_name: str, depth: int = 0) -> list:
|
|
105
|
+
default_item = {
|
|
106
|
+
'full_op_name': op_name,
|
|
107
|
+
'type': None,
|
|
108
|
+
'Max': None,
|
|
109
|
+
'Min': None,
|
|
110
|
+
'Mean': None,
|
|
111
|
+
'Norm': None,
|
|
112
|
+
'dtype': None,
|
|
113
|
+
'shape': None,
|
|
114
|
+
'md5': None,
|
|
115
|
+
'value': None,
|
|
116
|
+
'data_name': '-1'
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
if depth > Const.MAX_DEPTH:
|
|
120
|
+
logger.error(f'parse of api/module of {op_name} exceeds the recursion limit.')
|
|
121
|
+
raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
|
|
122
|
+
|
|
123
|
+
if op_data is None:
|
|
124
|
+
return [default_item]
|
|
125
|
+
elif not op_data:
|
|
126
|
+
return []
|
|
127
|
+
|
|
128
|
+
item_list = []
|
|
129
|
+
if isinstance(op_data, list):
|
|
130
|
+
for i, data in enumerate(op_data):
|
|
131
|
+
item_list.extend(op_item_parse(data, op_name + Const.SEP + str(i), depth + 1))
|
|
132
|
+
elif isinstance(op_data, dict):
|
|
133
|
+
if is_leaf_data(op_data):
|
|
134
|
+
return [gen_op_item(op_data, op_name)]
|
|
135
|
+
for sub_name, sub_data in op_data.items():
|
|
136
|
+
item_list.extend(op_item_parse(sub_data, op_name + Const.SEP + str(sub_name), depth + 1))
|
|
175
137
|
return item_list
|
|
176
138
|
|
|
177
139
|
|
|
140
|
+
def is_leaf_data(op_data):
|
|
141
|
+
return 'type' in op_data and isinstance(op_data['type'], str)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def gen_op_item(op_data, op_name):
|
|
145
|
+
op_item = {}
|
|
146
|
+
op_item.update(op_data)
|
|
147
|
+
op_item['full_op_name'] = op_name
|
|
148
|
+
op_item['data_name'] = op_data.get('data_name', '-1')
|
|
149
|
+
|
|
150
|
+
params = ['Max', 'Min', 'Mean', 'Norm']
|
|
151
|
+
for i in params:
|
|
152
|
+
if i not in op_item:
|
|
153
|
+
op_item[i] = None
|
|
154
|
+
|
|
155
|
+
if not op_item.get('dtype'):
|
|
156
|
+
if op_item.get('type') == 'torch.Size':
|
|
157
|
+
op_item['dtype'] = op_data.get('type')
|
|
158
|
+
op_item['shape'] = str(op_data.get('value'))
|
|
159
|
+
elif op_item.get('type') == 'slice':
|
|
160
|
+
op_item['dtype'] = op_data.get('type')
|
|
161
|
+
op_item['shape'] = str(np.shape(np.array(op_data.get('value'))))
|
|
162
|
+
else:
|
|
163
|
+
op_item['dtype'] = str(type(op_data.get('value')))
|
|
164
|
+
op_item['shape'] = '[]'
|
|
165
|
+
for i in params:
|
|
166
|
+
op_item[i] = op_data.get('value')
|
|
167
|
+
if not op_item.get('md5'):
|
|
168
|
+
op_item['md5'] = f"{zlib.crc32(str(op_data.get('value', '')).encode()):08x}"
|
|
169
|
+
|
|
170
|
+
return op_item
|
|
171
|
+
|
|
172
|
+
|
|
178
173
|
def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
179
174
|
"""
|
|
180
175
|
Function Description:
|
|
@@ -206,139 +201,196 @@ def resolve_api_special_parameters(data_dict, full_op_name, item_list):
|
|
|
206
201
|
item_list.append(parsed_item)
|
|
207
202
|
|
|
208
203
|
|
|
209
|
-
def
|
|
204
|
+
def process_summary_data(summary_data):
|
|
205
|
+
"""处理summary_data中的nan值,返回处理后的列表"""
|
|
206
|
+
return [CompareConst.NAN if isinstance(x, float) and math.isnan(x) else x for x in summary_data]
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def get_rela_diff_summary_mode(result_item, npu_summary_data, bench_summary_data, err_msg):
|
|
210
|
+
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
211
|
+
warning_flag = False
|
|
212
|
+
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
213
|
+
if all(isinstance(val, (float, int)) and not isinstance(val, bool) for val in [npu_val, bench_val]):
|
|
214
|
+
diff = npu_val - bench_val
|
|
215
|
+
if math.isnan(diff):
|
|
216
|
+
diff = CompareConst.NAN
|
|
217
|
+
relative = CompareConst.NAN
|
|
218
|
+
else:
|
|
219
|
+
if bench_val != 0:
|
|
220
|
+
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
221
|
+
else:
|
|
222
|
+
relative = CompareConst.N_A
|
|
223
|
+
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + CompareConst.EPSILON)
|
|
224
|
+
if magnitude_diff > CompareConst.MAGNITUDE:
|
|
225
|
+
warning_flag = True
|
|
226
|
+
result_item[start_idx + i] = diff
|
|
227
|
+
result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = relative
|
|
228
|
+
else:
|
|
229
|
+
result_item[start_idx + i] = CompareConst.N_A
|
|
230
|
+
result_item[start_idx + i + CompareConst.STATISTICS_INDICATOR_NUM] = CompareConst.N_A
|
|
231
|
+
|
|
232
|
+
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
233
|
+
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
234
|
+
for i in range(start_idx, len(result_item)):
|
|
235
|
+
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
236
|
+
result_item[i] = f'{result_item[i]}\t'
|
|
237
|
+
return result_item, accuracy_check, err_msg
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@dataclass
|
|
241
|
+
class ApiItemInfo:
|
|
242
|
+
name: str
|
|
243
|
+
struct: tuple
|
|
244
|
+
stack_info: list
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def stack_column_process(result_item, has_stack, index, key, npu_stack_info):
|
|
248
|
+
if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT:
|
|
249
|
+
result_item.extend(npu_stack_info)
|
|
250
|
+
else:
|
|
251
|
+
result_item.append(CompareConst.NONE)
|
|
252
|
+
return result_item
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def result_item_init(n_info, b_info, dump_mode):
|
|
256
|
+
n_len = len(n_info.struct)
|
|
257
|
+
b_len = len(b_info.struct)
|
|
258
|
+
struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1)
|
|
259
|
+
if struct_long_enough:
|
|
260
|
+
result_item = [
|
|
261
|
+
n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1]
|
|
262
|
+
]
|
|
263
|
+
if dump_mode == Const.MD5:
|
|
264
|
+
md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF
|
|
265
|
+
result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result])
|
|
266
|
+
elif dump_mode == Const.SUMMARY:
|
|
267
|
+
result_item.extend([" "] * 8)
|
|
268
|
+
else:
|
|
269
|
+
result_item.extend([" "] * 5)
|
|
270
|
+
else:
|
|
271
|
+
err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \
|
|
272
|
+
f"npu_info_struct is {n_info.struct}\n" \
|
|
273
|
+
f"bench_info_struct is {b_info.struct}"
|
|
274
|
+
logger.error(err_msg)
|
|
275
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
|
|
276
|
+
return result_item
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def get_accuracy(result, n_dict, b_dict, dump_mode):
|
|
210
280
|
def get_accuracy_core(n_start, n_len, b_start, b_len, key):
|
|
211
281
|
min_len = min(n_len, b_len)
|
|
212
282
|
npu_stack_info = n_dict.get("stack_info", None)
|
|
213
283
|
bench_stack_info = b_dict.get("stack_info", None)
|
|
214
284
|
has_stack = npu_stack_info and bench_stack_info
|
|
215
285
|
|
|
216
|
-
|
|
217
|
-
if all_mode_bool:
|
|
286
|
+
if dump_mode == Const.ALL:
|
|
218
287
|
npu_data_name = n_dict.get("data_name", None)
|
|
219
288
|
bench_data_name = b_dict.get("data_name", None)
|
|
220
289
|
|
|
221
290
|
for index in range(min_len):
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
b_struct = b_dict[key][index]
|
|
291
|
+
n_name = safe_get_value(n_dict, n_start + index, "n_dict", key="op_name")
|
|
292
|
+
b_name = safe_get_value(b_dict, b_start + index, "b_dict", key="op_name")
|
|
293
|
+
n_struct = safe_get_value(n_dict, index, "n_dict", key=key)
|
|
294
|
+
b_struct = safe_get_value(b_dict, index, "b_dict", key=key)
|
|
227
295
|
err_msg = ""
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
result_item.append(CompareConst.NONE)
|
|
296
|
+
|
|
297
|
+
npu_info = ApiItemInfo(n_name, n_struct, npu_stack_info)
|
|
298
|
+
bench_info = ApiItemInfo(b_name, b_struct, bench_stack_info)
|
|
299
|
+
result_item = result_item_init(npu_info, bench_info, dump_mode)
|
|
300
|
+
|
|
301
|
+
if dump_mode == Const.MD5:
|
|
302
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
236
303
|
result.append(result_item)
|
|
237
304
|
continue
|
|
238
305
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
result_item.extend(bench_summary_data)
|
|
250
|
-
|
|
251
|
-
if summary_compare:
|
|
252
|
-
start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
|
|
253
|
-
warning_flag = False
|
|
254
|
-
for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
|
|
255
|
-
if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
|
|
256
|
-
diff = npu_val - bench_val
|
|
257
|
-
if bench_val != 0:
|
|
258
|
-
relative = str(abs((diff / bench_val) * 100)) + '%'
|
|
259
|
-
else:
|
|
260
|
-
relative = "N/A"
|
|
261
|
-
result_item[start_idx + i] = diff
|
|
262
|
-
result_item[start_idx + i + 4] = relative
|
|
263
|
-
magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
|
|
264
|
-
if magnitude_diff > 0.5:
|
|
265
|
-
warning_flag = True
|
|
266
|
-
else:
|
|
267
|
-
result_item[start_idx + i] = CompareConst.NONE
|
|
268
|
-
accuracy_check = CompareConst.WARNING if warning_flag else ""
|
|
269
|
-
err_msg += "Need double check api accuracy." if warning_flag else ""
|
|
270
|
-
for i in range(start_idx, len(result_item)):
|
|
271
|
-
if str(result_item[i]) in ('inf', '-inf', 'nan'):
|
|
272
|
-
result_item[i] = f'{result_item[i]}\t'
|
|
273
|
-
|
|
274
|
-
result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES)
|
|
306
|
+
npu_summary_data = safe_get_value(n_dict, n_start + index, "n_dict", key=CompareConst.SUMMARY)
|
|
307
|
+
bench_summary_data = safe_get_value(b_dict, b_start + index, "b_dict", key=CompareConst.SUMMARY)
|
|
308
|
+
result_item.extend(process_summary_data(npu_summary_data))
|
|
309
|
+
result_item.extend(process_summary_data(bench_summary_data))
|
|
310
|
+
|
|
311
|
+
if dump_mode == Const.SUMMARY:
|
|
312
|
+
result_item, accuracy_check, err_msg = get_rela_diff_summary_mode(result_item, npu_summary_data,
|
|
313
|
+
bench_summary_data, err_msg)
|
|
314
|
+
|
|
315
|
+
result_item.append(accuracy_check if dump_mode == Const.SUMMARY else CompareConst.ACCURACY_CHECK_YES)
|
|
275
316
|
result_item.append(err_msg)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
result_item.append(CompareConst.NONE)
|
|
280
|
-
if all_mode_bool:
|
|
281
|
-
result_item.append(npu_data_name[n_start + index])
|
|
317
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
318
|
+
if dump_mode == Const.ALL:
|
|
319
|
+
result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
|
|
282
320
|
|
|
283
321
|
result.append(result_item)
|
|
284
322
|
|
|
285
323
|
if n_len > b_len:
|
|
286
324
|
for index in range(b_len, n_len):
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
325
|
+
try:
|
|
326
|
+
n_name = n_dict['op_name'][n_start + index]
|
|
327
|
+
n_struct = n_dict[key][index]
|
|
328
|
+
if dump_mode == Const.MD5:
|
|
329
|
+
result_item = [
|
|
330
|
+
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
331
|
+
n_struct[2], CompareConst.NAN, CompareConst.NAN
|
|
332
|
+
]
|
|
333
|
+
result.append(result_item)
|
|
334
|
+
continue
|
|
335
|
+
result_item = [
|
|
336
|
+
n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, n_struct[1], CompareConst.NAN,
|
|
337
|
+
" ", " ", " ", " ", " "
|
|
338
|
+
]
|
|
339
|
+
summary_data = n_dict.get(CompareConst.SUMMARY)[n_start + index]
|
|
340
|
+
result_item.extend(summary_data)
|
|
341
|
+
summary_data = [CompareConst.NAN for _ in range(len(n_dict.get(CompareConst.SUMMARY)[0]))]
|
|
342
|
+
result_item.extend(summary_data)
|
|
343
|
+
except IndexError as e:
|
|
344
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
345
|
+
f"n_dict is {n_dict}"
|
|
346
|
+
logger.error(err_msg)
|
|
347
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
300
348
|
|
|
301
349
|
err_msg = ""
|
|
302
350
|
result_item.append(CompareConst.ACCURACY_CHECK_YES)
|
|
303
351
|
result_item.append(err_msg)
|
|
304
|
-
|
|
305
|
-
if
|
|
306
|
-
result_item.
|
|
307
|
-
else:
|
|
308
|
-
result_item.append(CompareConst.NONE)
|
|
309
|
-
if all_mode_bool:
|
|
310
|
-
result_item.append(npu_data_name[n_start + index])
|
|
352
|
+
result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info)
|
|
353
|
+
if dump_mode == Const.ALL:
|
|
354
|
+
result_item.append(safe_get_value(npu_data_name, n_start + index, "npu_data_name"))
|
|
311
355
|
|
|
312
356
|
result.append(result_item)
|
|
313
357
|
|
|
314
358
|
n_num = len(n_dict['op_name'])
|
|
315
359
|
b_num = len(b_dict['op_name'])
|
|
316
|
-
n_num_input = len([name for name in n_dict['op_name']
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
n_num_output = n_num - n_num_input
|
|
321
|
-
b_num_output = b_num - b_num_input
|
|
360
|
+
n_num_input = len([name for name in n_dict['op_name']
|
|
361
|
+
if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
|
|
362
|
+
b_num_input = len([name for name in b_dict['op_name']
|
|
363
|
+
if Const.INPUT in name.split(Const.SEP) or Const.KWARGS in name.split(Const.SEP)])
|
|
364
|
+
n_num_output = n_num - n_num_input
|
|
365
|
+
b_num_output = b_num - b_num_input
|
|
322
366
|
get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct')
|
|
323
|
-
get_accuracy_core(n_num_input,
|
|
324
|
-
get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct')
|
|
367
|
+
get_accuracy_core(n_num_input, n_num_output, b_num_input, b_num_output, 'output_struct')
|
|
325
368
|
|
|
326
369
|
|
|
327
|
-
def get_un_match_accuracy(result, n_dict,
|
|
370
|
+
def get_un_match_accuracy(result, n_dict, dump_mode):
|
|
328
371
|
index_out = 0
|
|
329
372
|
npu_stack_info = n_dict.get("stack_info", None)
|
|
330
373
|
bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A
|
|
331
374
|
err_msg = CompareConst.NO_BENCH
|
|
332
375
|
accuracy_check_res = CompareConst.N_A
|
|
333
376
|
for index, n_name in enumerate(n_dict["op_name"]):
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
377
|
+
name_ele_list = n_name.split(Const.SEP)
|
|
378
|
+
if Const.INPUT in name_ele_list or Const.KWARGS in name_ele_list:
|
|
379
|
+
n_struct = safe_get_value(n_dict, index, "n_dict", key=CompareConst.INPUT_STRUCT)
|
|
380
|
+
if Const.OUTPUT in name_ele_list:
|
|
381
|
+
n_struct = safe_get_value(n_dict, index_out, "n_dict", key=CompareConst.OUTPUT_STRUCT)
|
|
338
382
|
index_out += 1
|
|
339
383
|
|
|
340
|
-
|
|
341
|
-
|
|
384
|
+
try:
|
|
385
|
+
result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape]
|
|
386
|
+
except IndexError as e:
|
|
387
|
+
err_msg = "index out of bounds error occurs, please check!\n" \
|
|
388
|
+
f"op_name of n_dict is {n_dict['op_name']}\n" \
|
|
389
|
+
f"input_struct of n_dict is {n_dict[CompareConst.INPUT_STRUCT]}\n" \
|
|
390
|
+
f"output_struct of n_dict is {n_dict[CompareConst.OUTPUT_STRUCT]}"
|
|
391
|
+
logger.error(err_msg)
|
|
392
|
+
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
|
|
393
|
+
if dump_mode == Const.MD5:
|
|
342
394
|
result_item.extend([CompareConst.N_A] * 3)
|
|
343
395
|
if npu_stack_info and index == 0:
|
|
344
396
|
result_item.extend(npu_stack_info)
|
|
@@ -346,11 +398,11 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
|
346
398
|
result_item.append(CompareConst.NONE)
|
|
347
399
|
result.append(result_item)
|
|
348
400
|
continue
|
|
349
|
-
if
|
|
401
|
+
if dump_mode == Const.SUMMARY:
|
|
350
402
|
result_item.extend([CompareConst.N_A] * 8)
|
|
351
403
|
else:
|
|
352
404
|
result_item.extend([CompareConst.N_A] * 5)
|
|
353
|
-
npu_summary_data = n_dict
|
|
405
|
+
npu_summary_data = safe_get_value(n_dict, index, "n_dict", key=CompareConst.SUMMARY)
|
|
354
406
|
result_item.extend(npu_summary_data)
|
|
355
407
|
bench_summary_data = [CompareConst.N_A] * 4
|
|
356
408
|
result_item.extend(bench_summary_data)
|
|
@@ -360,22 +412,21 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare):
|
|
|
360
412
|
result_item.extend(npu_stack_info)
|
|
361
413
|
else:
|
|
362
414
|
result_item.append(CompareConst.NONE)
|
|
363
|
-
if
|
|
415
|
+
if dump_mode == Const.ALL and result_item[1] == CompareConst.N_A:
|
|
364
416
|
result_item.extend(["-1"])
|
|
365
417
|
result.append(result_item)
|
|
366
418
|
|
|
367
419
|
|
|
368
|
-
def merge_tensor(tensor_list,
|
|
420
|
+
def merge_tensor(tensor_list, dump_mode):
|
|
369
421
|
op_dict = {}
|
|
370
422
|
op_dict["op_name"] = []
|
|
371
|
-
op_dict[
|
|
372
|
-
op_dict[
|
|
373
|
-
op_dict[
|
|
374
|
-
op_dict[
|
|
423
|
+
op_dict[CompareConst.INPUT_STRUCT] = []
|
|
424
|
+
op_dict[CompareConst.KWARGS_STRUCT] = []
|
|
425
|
+
op_dict[CompareConst.OUTPUT_STRUCT] = []
|
|
426
|
+
op_dict[Const.SUMMARY] = []
|
|
375
427
|
op_dict["stack_info"] = []
|
|
376
428
|
|
|
377
|
-
|
|
378
|
-
if all_mode_bool:
|
|
429
|
+
if dump_mode == Const.ALL:
|
|
379
430
|
op_dict["data_name"] = []
|
|
380
431
|
|
|
381
432
|
for tensor in tensor_list:
|
|
@@ -383,36 +434,45 @@ def merge_tensor(tensor_list, summary_compare, md5_compare):
|
|
|
383
434
|
op_dict['stack_info'].append(tensor['full_info'])
|
|
384
435
|
break
|
|
385
436
|
op_dict["op_name"].append(tensor['full_op_name'])
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']])
|
|
437
|
+
name_ele_list = tensor['full_op_name'].split(Const.SEP)
|
|
438
|
+
name_to_struct_mapping = {
|
|
439
|
+
Const.INPUT: CompareConst.INPUT_STRUCT,
|
|
440
|
+
Const.KWARGS: CompareConst.KWARGS_STRUCT,
|
|
441
|
+
Const.OUTPUT: CompareConst.OUTPUT_STRUCT
|
|
442
|
+
}
|
|
443
|
+
for name_key, struct_key in name_to_struct_mapping.items():
|
|
444
|
+
if name_key in name_ele_list:
|
|
445
|
+
if dump_mode == Const.MD5:
|
|
446
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5]))
|
|
447
|
+
else:
|
|
448
|
+
op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE]))
|
|
449
|
+
break
|
|
450
|
+
op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]])
|
|
402
451
|
|
|
403
|
-
if
|
|
452
|
+
if dump_mode == Const.ALL:
|
|
404
453
|
op_dict["data_name"].append(tensor['data_name'])
|
|
454
|
+
data_name = safe_get_value(op_dict, -1, "op_dict", key="data_name").rsplit(Const.SEP, 1)[0]
|
|
455
|
+
if data_name != "-1":
|
|
456
|
+
op_dict["op_name"][-1] = data_name
|
|
405
457
|
|
|
406
|
-
if not op_dict[
|
|
407
|
-
del op_dict[
|
|
458
|
+
if not op_dict[CompareConst.KWARGS_STRUCT]:
|
|
459
|
+
del op_dict[CompareConst.KWARGS_STRUCT]
|
|
408
460
|
return op_dict if op_dict["op_name"] else {}
|
|
409
461
|
|
|
410
462
|
|
|
463
|
+
def print_compare_ends_info():
    """Log a three-line starred banner around the compare-finished message.

    The banner is ``Const.FILL_CHAR_NUMS`` characters wider than
    ``CompareConst.COMPARE_ENDS_SUCCESSFULLY``; the message is centred
    between a leading and a trailing ``*``.
    """
    message = CompareConst.COMPARE_ENDS_SUCCESSFULLY
    total_len = len(message) + Const.FILL_CHAR_NUMS
    border = '*' * total_len

    logger.info(border)
    # Two columns of the width are taken by the enclosing asterisks.
    logger.info(f"*{message.center(total_len - 2)}*")
    logger.info(border)
|
|
468
|
+
|
|
469
|
+
|
|
411
470
|
def _compare_parser(parser):
|
|
412
471
|
parser.add_argument("-i", "--input_path", dest="input_path", type=str,
|
|
413
|
-
help="<Required> The compare input path, a dict json.",
|
|
472
|
+
help="<Required> The compare input path, a dict json.", required=True)
|
|
414
473
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
415
|
-
help="<Required> The compare task result out path.",
|
|
474
|
+
help="<Required> The compare task result out path. Default path: ./output",
|
|
475
|
+
required=False, default="./output", nargs="?", const="./output")
|
|
416
476
|
parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true",
|
|
417
477
|
help="<optional> Whether to save stack info.", required=False)
|
|
418
478
|
parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true",
|
|
@@ -423,8 +483,7 @@ def _compare_parser(parser):
|
|
|
423
483
|
help="<optional> The cell mapping file path.", required=False)
|
|
424
484
|
parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True,
|
|
425
485
|
help="<optional> The api mapping file path.", required=False)
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
486
|
+
parser.add_argument("-dm", "--data_mapping", dest="data_mapping", type=str,
|
|
487
|
+
help="<optional> The data mapping file path.", required=False)
|
|
488
|
+
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
489
|
+
help="<optional> The layer mapping file path.", required=False)
|