mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +84 -18
- msprobe/__init__.py +16 -1
- msprobe/config.json +1 -5
- msprobe/core/advisor/advisor.py +16 -11
- msprobe/core/advisor/advisor_const.py +6 -7
- msprobe/core/advisor/advisor_result.py +12 -12
- msprobe/core/common/const.py +164 -3
- msprobe/core/common/exceptions.py +26 -4
- msprobe/core/common/file_utils.py +196 -27
- msprobe/core/common/inplace_op_checker.py +53 -0
- msprobe/core/common/inplace_ops.yaml +251 -0
- msprobe/core/common/log.py +46 -18
- msprobe/core/common/utils.py +308 -209
- msprobe/core/common_config.py +60 -38
- msprobe/core/compare/acc_compare.py +332 -94
- msprobe/core/compare/check.py +104 -22
- msprobe/core/compare/compare_cli.py +42 -5
- msprobe/core/compare/highlight.py +162 -57
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/multiprocessing_compute.py +33 -8
- msprobe/core/compare/npy_compare.py +73 -29
- msprobe/core/compare/utils.py +306 -247
- msprobe/core/data_dump/data_collector.py +44 -43
- msprobe/core/data_dump/data_processor/base.py +88 -35
- msprobe/core/data_dump/data_processor/factory.py +20 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
- msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
- msprobe/core/data_dump/json_writer.py +63 -42
- msprobe/core/data_dump/scope.py +143 -48
- msprobe/core/grad_probe/constant.py +31 -13
- msprobe/core/grad_probe/grad_compare.py +20 -4
- msprobe/core/grad_probe/utils.py +44 -3
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +29 -9
- msprobe/docs/02.config_introduction.md +83 -84
- msprobe/docs/03.config_examples.md +3 -20
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +143 -13
- msprobe/docs/06.data_dump_MindSpore.md +197 -88
- msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
- msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
- msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
- msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
- msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
- msprobe/docs/17.grad_probe.md +19 -22
- msprobe/docs/18.online_dispatch.md +89 -0
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/ms_dump.png +0 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/docs/img/pt_dump.png +0 -0
- msprobe/mindspore/__init__.py +16 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
- msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +58 -13
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +60 -5
- msprobe/mindspore/compare/distributed_compare.py +15 -28
- msprobe/mindspore/compare/ms_compare.py +319 -158
- msprobe/mindspore/compare/ms_graph_compare.py +99 -49
- msprobe/mindspore/debugger/debugger_config.py +20 -14
- msprobe/mindspore/debugger/precision_debugger.py +43 -13
- msprobe/mindspore/dump/dump_tool_factory.py +18 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
- msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
- msprobe/mindspore/dump/jit_dump.py +56 -20
- msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
- msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
- msprobe/mindspore/free_benchmark/common/config.py +15 -0
- msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
- msprobe/mindspore/free_benchmark/common/utils.py +37 -8
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
- msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
- msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
- msprobe/mindspore/grad_probe/global_context.py +44 -14
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +75 -150
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +23 -3
- msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
- msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
- msprobe/pytorch/bench_functions/__init__.py +18 -3
- msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
- msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
- msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
- msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
- msprobe/pytorch/bench_functions/linear.py +15 -0
- msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
- msprobe/pytorch/bench_functions/rms_norm.py +15 -0
- msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
- msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
- msprobe/pytorch/bench_functions/swiglu.py +29 -6
- msprobe/pytorch/common/__init__.py +15 -0
- msprobe/pytorch/common/log.py +18 -6
- msprobe/pytorch/common/parse_json.py +31 -16
- msprobe/pytorch/common/utils.py +96 -40
- msprobe/pytorch/compare/distributed_compare.py +13 -14
- msprobe/pytorch/compare/match.py +15 -0
- msprobe/pytorch/compare/pt_compare.py +44 -10
- msprobe/pytorch/debugger/debugger_config.py +69 -52
- msprobe/pytorch/debugger/precision_debugger.py +72 -24
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/__init__.py +20 -5
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +43 -0
- msprobe/pytorch/free_benchmark/common/params.py +23 -1
- msprobe/pytorch/free_benchmark/common/utils.py +43 -5
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
- msprobe/pytorch/free_benchmark/main.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
- msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
- msprobe/pytorch/function_factory.py +17 -2
- msprobe/pytorch/functional/module_dump.py +84 -0
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +16 -1
- msprobe/pytorch/hook_module/api_registry.py +13 -8
- msprobe/pytorch/hook_module/hook_module.py +17 -19
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/utils.py +4 -6
- msprobe/pytorch/hook_module/wrap_aten.py +12 -11
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
- msprobe/pytorch/hook_module/wrap_functional.py +21 -20
- msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
- msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
- msprobe/pytorch/hook_module/wrap_torch.py +4 -6
- msprobe/pytorch/hook_module/wrap_vf.py +4 -6
- msprobe/pytorch/module_processer.py +18 -6
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +38 -48
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +60 -39
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
- msprobe/pytorch/online_dispatch/utils.py +48 -23
- msprobe/pytorch/parse.py +15 -0
- msprobe/pytorch/parse_tool/cli.py +5 -6
- msprobe/pytorch/parse_tool/lib/compare.py +19 -26
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
- msprobe/pytorch/parse_tool/lib/utils.py +40 -55
- msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
- msprobe/pytorch/pt_config.py +192 -40
- msprobe/pytorch/service.py +110 -35
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
- msprobe/docs/04.acl_config_examples.md +0 -76
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
- msprobe/pytorch/functional/dump_module.py +0 -39
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
- /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
|
@@ -1,29 +1,170 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
|
-
import
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
17
|
+
import re
|
|
18
|
+
|
|
19
|
+
from collections import defaultdict
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
24
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
8
25
|
from msprobe.core.common.exceptions import FileCheckException
|
|
26
|
+
from msprobe.core.common.file_utils import (FileOpen, create_directory, load_json,
|
|
27
|
+
load_npy, load_yaml)
|
|
28
|
+
from msprobe.core.common.log import logger
|
|
29
|
+
from msprobe.core.common.utils import (CompareException, check_compare_param,
|
|
30
|
+
check_configuration_param,
|
|
31
|
+
get_dump_mode, set_dump_path, check_op_str_pattern_valid)
|
|
32
|
+
from msprobe.core.compare.check import dtype_mapping
|
|
9
33
|
from msprobe.core.compare.acc_compare import Comparator
|
|
10
|
-
from msprobe.core.compare.
|
|
34
|
+
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
11
35
|
|
|
12
36
|
|
|
13
37
|
class MSComparator(Comparator):
|
|
14
|
-
|
|
38
|
+
"""
|
|
39
|
+
用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
|
|
40
|
+
cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
|
|
41
|
+
api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
|
|
42
|
+
data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
|
|
43
|
+
is_cross_framework: 是否跨框架。
|
|
44
|
+
"""
|
|
45
|
+
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
|
|
15
46
|
self.frame_name = MSComparator.__name__
|
|
16
47
|
self.cell_mapping = cell_mapping
|
|
17
48
|
self.api_mapping = api_mapping
|
|
18
|
-
self.
|
|
49
|
+
self.data_mapping = data_mapping
|
|
50
|
+
if data_mapping:
|
|
51
|
+
self.cross_frame = is_cross_framework
|
|
52
|
+
else:
|
|
53
|
+
self.cross_frame = cell_mapping is not None or api_mapping is not None
|
|
19
54
|
self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
|
|
20
55
|
self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
|
|
21
56
|
if api_mapping is not None:
|
|
22
57
|
self.ms_to_pt_mapping = self.load_internal_api()
|
|
58
|
+
|
|
59
|
+
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
60
|
+
self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
|
|
61
|
+
elif isinstance(self.data_mapping, dict):
|
|
62
|
+
self.data_mapping_dict = self.data_mapping
|
|
63
|
+
else:
|
|
64
|
+
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
65
|
+
f"{type(self.data_mapping)}")
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def calc_accuracy(cls, result_df, dump_mode, header):
|
|
69
|
+
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
70
|
+
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
71
|
+
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
72
|
+
|
|
73
|
+
def calc_summary_diff(data_type: str):
|
|
74
|
+
def type_check(val):
|
|
75
|
+
check_series = pd.Series(False, index=val.index)
|
|
76
|
+
val_str = val.astype(str)
|
|
77
|
+
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
78
|
+
return check_series
|
|
23
79
|
|
|
80
|
+
def get_number(val):
|
|
81
|
+
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
82
|
+
|
|
83
|
+
ms_val = result_df['NPU ' + data_type]
|
|
84
|
+
pt_val = result_df['Bench ' + data_type]
|
|
85
|
+
diff_name = data_type.capitalize() + ' diff'
|
|
86
|
+
rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
|
|
87
|
+
condition_na = ~type_check(ms_val) | ~type_check(pt_val)
|
|
88
|
+
result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
|
|
89
|
+
result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
|
|
90
|
+
condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
|
|
91
|
+
condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
|
|
92
|
+
result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
|
|
93
|
+
condition_pt_zero = pt_val == 0
|
|
94
|
+
result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
|
|
95
|
+
condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
|
|
96
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
97
|
+
pt_val[condition_ref_err] * 100)
|
|
98
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
|
|
99
|
+
.abs().astype(str) + '%')
|
|
100
|
+
magnitude = get_number(result_df[diff_name]).abs() / (
|
|
101
|
+
pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
|
|
102
|
+
return magnitude > CompareConst.MAGNITUDE
|
|
103
|
+
|
|
104
|
+
if dump_mode == Const.MD5:
|
|
105
|
+
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
106
|
+
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
107
|
+
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
108
|
+
elif dump_mode == Const.SUMMARY:
|
|
109
|
+
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
110
|
+
warning_flag = pd.DataFrame(warning_list).all()
|
|
111
|
+
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
112
|
+
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
113
|
+
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
114
|
+
else:
|
|
115
|
+
fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
116
|
+
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
117
|
+
CompareConst.ERROR_MESSAGE]
|
|
118
|
+
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
119
|
+
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
120
|
+
return result_df[header]
|
|
121
|
+
|
|
122
|
+
@classmethod
|
|
123
|
+
def make_result_df(cls, result, stack_mode, dump_mode):
|
|
124
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
|
|
125
|
+
|
|
126
|
+
if stack_mode:
|
|
127
|
+
header.append(CompareConst.STACK)
|
|
128
|
+
if dump_mode == Const.ALL:
|
|
129
|
+
header.append(CompareConst.DATA_NAME)
|
|
130
|
+
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
131
|
+
'op_name_y': CompareConst.BENCH_NAME,
|
|
132
|
+
'dtype_x': CompareConst.NPU_DTYPE,
|
|
133
|
+
'dtype_y': CompareConst.BENCH_DTYPE,
|
|
134
|
+
'shape_x': CompareConst.NPU_SHAPE,
|
|
135
|
+
'shape_y': CompareConst.BENCH_SHAPE,
|
|
136
|
+
'md5_x': CompareConst.NPU_MD5,
|
|
137
|
+
'md5_y': CompareConst.BENCH_MD5,
|
|
138
|
+
'data_name_x': CompareConst.DATA_NAME,
|
|
139
|
+
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
140
|
+
|
|
141
|
+
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
142
|
+
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
143
|
+
CompareConst.BENCH_NORM]
|
|
144
|
+
def set_summary(summary):
|
|
145
|
+
if summary == CompareConst.N_A:
|
|
146
|
+
return [CompareConst.N_A] * 4
|
|
147
|
+
summary_list = []
|
|
148
|
+
for i in summary:
|
|
149
|
+
if i is None:
|
|
150
|
+
summary_list.append(CompareConst.N_A)
|
|
151
|
+
elif str(i).lower() == 'nan':
|
|
152
|
+
summary_list.append(CompareConst.NAN)
|
|
153
|
+
else:
|
|
154
|
+
summary_list.append(i)
|
|
155
|
+
return summary_list
|
|
156
|
+
|
|
157
|
+
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
158
|
+
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
159
|
+
result_df = pd.DataFrame(columns=header)
|
|
160
|
+
for h in header:
|
|
161
|
+
if h in result.columns:
|
|
162
|
+
result_df[h] = result[h]
|
|
163
|
+
return cls.calc_accuracy(result_df, dump_mode, header)
|
|
164
|
+
|
|
24
165
|
def load_internal_api(self):
|
|
25
166
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
26
|
-
yaml_path = os.path.join(cur_path,
|
|
167
|
+
yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
|
|
27
168
|
return load_yaml(yaml_path)
|
|
28
169
|
|
|
29
170
|
def load_mapping_file(self, mapping_file):
|
|
@@ -34,171 +175,184 @@ class MSComparator(Comparator):
|
|
|
34
175
|
return mapping_dict
|
|
35
176
|
|
|
36
177
|
def process_cell_mapping(self, npu_op_name):
|
|
37
|
-
npu_op_name
|
|
178
|
+
if not npu_op_name or not re.match(r'.+(?:for|back)ward\..+', npu_op_name):
|
|
179
|
+
return CompareConst.N_A
|
|
180
|
+
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
38
181
|
if self.cell_mapping_dict:
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
182
|
+
# get cell name & class name from op_name
|
|
183
|
+
# Cell.fc1.Dense.forward.0.input.0
|
|
184
|
+
cell_name = re.split(r'\.(?:for|back)ward\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
185
|
+
if cell_name in self.cell_mapping_dict:
|
|
186
|
+
npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
45
187
|
return npu_op_name
|
|
46
188
|
|
|
47
|
-
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
48
|
-
npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
|
|
49
|
-
npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
|
|
50
|
-
if self.cell_mapping is not None:
|
|
51
|
-
npu_op_name = self.process_cell_mapping(npu_op_name)
|
|
52
|
-
if self.api_mapping is not None:
|
|
53
|
-
npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
|
|
54
|
-
if isinstance(self.api_mapping, str):
|
|
55
|
-
npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new, bench_dict_new)
|
|
56
|
-
if target_dict:
|
|
57
|
-
bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
|
|
58
|
-
npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
|
|
59
|
-
struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
|
|
60
|
-
if not fuzzy_match:
|
|
61
|
-
return npu_op_name == bench_op_name and struct_match
|
|
62
|
-
is_match = True
|
|
63
|
-
try:
|
|
64
|
-
is_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
65
|
-
except Exception as err:
|
|
66
|
-
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
67
|
-
is_match = False
|
|
68
|
-
return is_match and struct_match
|
|
69
|
-
|
|
70
189
|
def read_npy_data(self, dir_path, file_name, load_pt_file=False):
|
|
190
|
+
if not file_name:
|
|
191
|
+
return None
|
|
71
192
|
data_path = os.path.join(dir_path, file_name)
|
|
72
193
|
if load_pt_file:
|
|
73
194
|
import torch
|
|
74
195
|
from msprobe.pytorch.common.utils import load_pt
|
|
75
|
-
data_value = load_pt(data_path).detach()
|
|
196
|
+
data_value = load_pt(data_path, True).detach()
|
|
76
197
|
if data_value.dtype == torch.bfloat16:
|
|
77
198
|
data_value = data_value.to(torch.float32)
|
|
78
199
|
data_value = data_value.numpy()
|
|
79
200
|
else:
|
|
80
201
|
data_value = load_npy(data_path)
|
|
81
|
-
return data_value
|
|
202
|
+
return data_value
|
|
82
203
|
|
|
83
|
-
def
|
|
84
|
-
for idx, _ in enumerate(npu_op_name):
|
|
85
|
-
npu_op_name[idx] = npu_op_name[idx].replace(target, para)
|
|
86
|
-
return npu_op_name
|
|
87
|
-
|
|
88
|
-
def process_internal_api_mapping(self, npu_op_name, bench_op_name):
|
|
204
|
+
def process_internal_api_mapping(self, npu_op_name):
|
|
89
205
|
# get api name & class name from op_name
|
|
90
206
|
# Functional.addcmul.0.forward.input.0
|
|
91
|
-
|
|
92
|
-
ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
|
|
93
|
-
pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
|
|
207
|
+
ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
|
|
94
208
|
class_name = ms_api_name.split(Const.SEP)[0]
|
|
95
209
|
if class_name == "Mint":
|
|
96
|
-
return
|
|
210
|
+
return npu_op_name.replace("Mint", "Torch")
|
|
97
211
|
elif class_name == "MintFunctional":
|
|
98
|
-
return
|
|
99
|
-
elif self.ms_to_pt_mapping.get(ms_api_name)
|
|
100
|
-
return
|
|
212
|
+
return npu_op_name.replace("MintFunctional", "Functional")
|
|
213
|
+
elif self.ms_to_pt_mapping.get(ms_api_name):
|
|
214
|
+
return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
|
|
101
215
|
else:
|
|
102
|
-
return npu_op_name
|
|
103
|
-
|
|
104
|
-
def remove_element(self, op_name, struct, summary, idx):
|
|
105
|
-
del op_name[idx]
|
|
106
|
-
del struct[idx]
|
|
107
|
-
del summary[idx]
|
|
216
|
+
return npu_op_name
|
|
108
217
|
|
|
109
218
|
def get_api_name(self, api_list):
    """Join the first two fields of a split op name into an API name.

    Args:
        api_list: op name already split on ``Const.SEP``
            (e.g. ``['Functional', 'addcmul', '0', ...]``).

    Returns:
        ``api_list[0] + Const.SEP + api_list[1]``.

    Raises:
        CompareException: ``INDEX_OUT_OF_BOUNDS_ERROR`` when ``api_list`` has
            fewer than two fields.
    """
    try:
        api_name = api_list[0] + Const.SEP + api_list[1]
    except IndexError as error:
        # Name the offending fields so the malformed dump entry can be located.
        logger.error(f'Failed to retrieve API name from {api_list}, '
                     'please check if the dump data is reasonable')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
    return api_name
|
|
172
225
|
|
|
173
|
-
def
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
226
|
+
def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
    """Match NPU dump entries with benchmark entries and build the comparison table.

    Args:
        file_lists: ``(npu_json_path, bench_json_path, stack_json_path)``.
        stack_mode: forwarded to ``MSComparator.make_result_df``.
        fuzzy_match: accepted for interface compatibility; not read here.
        dump_mode: forwarded to ``gen_data_df`` and ``MSComparator.make_result_df``.

    Returns:
        The result of ``MSComparator.make_result_df`` for the merged table.
    """
    npu_json_path, bench_json_path, stack_json_path = file_lists
    npu_json_data = load_json(npu_json_path)
    bench_json_data = load_json(bench_json_path)
    stack_json_data = load_json(stack_json_path)

    npu_df = self.gen_data_df(npu_json_data, stack_json_data, dump_mode)
    bench_df = self.gen_data_df(bench_json_data, stack_json_data, dump_mode)

    # Choose the key NPU rows are matched on; benchmark rows always use their op name.
    op_names = npu_df[CompareConst.OP_NAME]
    if self.cell_mapping:
        npu_df[CompareConst.COMPARE_KEY] = op_names.apply(self.process_cell_mapping)
    elif self.api_mapping:
        npu_df[CompareConst.COMPARE_KEY] = op_names.apply(self.process_internal_api_mapping)
        # A string api_mapping additionally applies the user-defined mapping table.
        if isinstance(self.api_mapping, str):
            self.modify_compare_data_with_user_mapping(npu_df, bench_df)
    else:
        npu_df[CompareConst.COMPARE_KEY] = op_names

    type_columns = [Const.DTYPE, Const.SHAPE]
    for frame in (npu_df, bench_df):
        frame[type_columns] = frame[type_columns].astype(str)
        frame[CompareConst.COMPARE_SHAPE] = frame[Const.SHAPE]
    bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]

    merge_keys = [CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE]
    merged = pd.merge(npu_df, bench_df, on=merge_keys, how='outer')
    # Keep only rows present on the NPU side, then fill the holes with N/A.
    merged = merged[merged['op_name_x'].notna()].fillna(CompareConst.N_A)

    npu_dtype = merged['dtype_x']
    bench_dtype = merged['dtype_y']
    if self.cross_frame:
        npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)

    # (npu dtype, bench dtype) pairs that count as matching besides equality.
    tolerated_pairs = (
        (Const.FLOAT16, Const.FLOAT32),
        (Const.FLOAT32, Const.FLOAT16),
        (Const.FLOAT16, Const.BFLOAT16),
        (Const.BFLOAT16, Const.FLOAT16),
        (Const.TORCH_FLOAT16, Const.TORCH_FLOAT32),
        (Const.TORCH_FLOAT32, Const.TORCH_FLOAT16),
        (Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16),
        (Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16),
    )
    dtype_compatible = npu_dtype == bench_dtype
    for npu_type, bench_type in tolerated_pairs:
        dtype_compatible = dtype_compatible | ((npu_dtype == npu_type) & (bench_dtype == bench_type))

    # Blank out the benchmark side of rows whose dtypes are incompatible.
    bench_side_columns = [column + '_y' for column in bench_df.columns]
    merged.loc[~dtype_compatible, bench_side_columns] = CompareConst.N_A
    return MSComparator.make_result_df(merged, stack_mode, dump_mode)
|
|
268
|
+
|
|
269
|
+
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
    """Rewrite NPU compare keys according to the user-defined API mapping table.

    A usable entry of ``self.api_mapping_dict`` is one whose ``ms_api`` occurs in
    ``npu_df`` and whose ``pt_api`` occurs in ``bench_df``. For each NPU row of
    that API:
      * the MindSpore API name is replaced by the PyTorch one;
      * the index right after the input/kwargs/output pattern is remapped
        (``ms_args[i] -> pt_args[i]`` for input and kwargs rows,
        ``ms_output[i] -> pt_output[i]`` for output rows).
    Rows whose index is not listed get the key ``<op_name>abandoned``, so they do
    not match any benchmark row.

    Entries whose ms/pt args or output lists differ in length are skipped with a
    warning.

    Args:
        npu_df: NPU frame. Its ``COMPARE_KEY`` column is updated in place. Rows
            are addressed with ``.loc`` by their position, so a default
            RangeIndex is expected (as produced by ``gen_data_df``).
        bench_df: benchmark frame; only read.

    Raises:
        CompareException: ``INVALID_DATA_ERROR`` for an op name that is neither an
            input, kwargs nor output entry.
    """
    def get_api_indices_dict(op_name_df):
        # Group row positions by API name.
        api_indices_dict = defaultdict(list)
        for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
            api = self.get_api_name(name.split(Const.SEP))
            api_indices_dict[api].append(op_index)
        return api_indices_dict

    def remap_index(op_name, pattern, index_map):
        # Compare the whole index token after `pattern`. A prefix test would let
        # ms index 1 also claim indices 10, 11, ...
        index_token = op_name.split(pattern)[1].split(Const.SEP)[0]
        pt_index = index_map.get(index_token)
        if pt_index is None:
            return None
        # The first occurrence of pattern+token is the one just inspected.
        return op_name.replace(pattern + index_token, pattern + pt_index, 1)

    ms_api_indices_dict = get_api_indices_dict(npu_df)
    pt_api_indices_dict = get_api_indices_dict(bench_df)

    for mapping_dict in self.api_mapping_dict:
        ms_args, pt_args = mapping_dict.get('ms_args'), mapping_dict.get('pt_args')
        ms_output, pt_output = mapping_dict.get('ms_output'), mapping_dict.get('pt_output')
        if len(ms_args) != len(pt_args) or len(ms_output) != len(pt_output):
            logger.warning('The user-defined mapping table is incorrect, '
                           'make sure that the number of parameters is equal')
            continue
        ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
        if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
            continue
        # Built once per entry. For a repeated ms index, the dict keeps the last pair.
        args_map = {str(ms): str(pt) for ms, pt in zip(ms_args, pt_args)}
        output_map = {str(ms): str(pt) for ms, pt in zip(ms_output, pt_output)}
        for index in ms_api_indices_dict.get(ms_api):
            op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
            if CompareConst.INPUT_PATTERN in op_name:
                compare_key = remap_index(op_name, CompareConst.INPUT_PATTERN, args_map)
            elif CompareConst.KWARGS_PATTERN in op_name:
                compare_key = remap_index(op_name, CompareConst.KWARGS_PATTERN, args_map)
            elif CompareConst.OUTPUT_PATTERN in op_name:
                compare_key = remap_index(op_name, CompareConst.OUTPUT_PATTERN, output_map)
            else:
                logger.error(f'Unexpected op_name: {op_name}')
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            if compare_key is None:
                compare_key = op_name + 'abandoned'
            npu_df.loc[index, CompareConst.COMPARE_KEY] = compare_key
|
|
312
|
+
|
|
313
|
+
def gen_data_df(self, data_json, stack_json, dump_mode):
    """Flatten a dump json into a DataFrame with one row per op-name entry.

    Columns:
      * always: op name, dtype, shape, summary and ``stack_info``;
      * plus ``data_name`` when ``dump_mode == Const.ALL``;
      * plus ``Const.MD5`` when ``dump_mode == Const.MD5``.

    Input and kwargs entries take their dtype/shape (and md5) from the
    input-struct queue; every other entry takes them from the output-struct
    queue. The queues returned by ``gen_merge_list`` are consumed front-first.
    """
    column_names = [CompareConst.OP_NAME, Const.DTYPE, Const.SHAPE, Const.SUMMARY, 'stack_info']
    if dump_mode == Const.ALL:
        column_names.append('data_name')
    elif dump_mode == Const.MD5:
        column_names.append(Const.MD5)
    table = {name: [] for name in column_names}

    for data_name in data_json['data']:
        check_op_str_pattern_valid(data_name)
        merge_list = self.gen_merge_list(data_json, data_name, stack_json, dump_mode)
        if not merge_list:
            continue

        def take(key):
            # Pop the next queued value for `key` from this merge_list.
            return merge_list[key].pop(0)

        for op_name in merge_list[CompareConst.OP_NAME]:
            table[CompareConst.OP_NAME].append(op_name)
            is_input_side = (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name)
            struct = take(CompareConst.INPUT_STRUCT if is_input_side else CompareConst.OUTPUT_STRUCT)
            table[Const.DTYPE].append(struct[0])
            table[Const.SHAPE].append(struct[1])
            if dump_mode == Const.MD5:
                table[Const.MD5].append(struct[2])
            table[Const.SUMMARY].append(take(Const.SUMMARY))
            table['stack_info'].append(merge_list['stack_info'][0])
            if dump_mode == Const.ALL:
                table['data_name'].append(take('data_name'))
    return pd.DataFrame(table)
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def check_cross_framework(bench_json_path):
    """Return True if any line of the benchmark json names a ``.pt`` data file.

    The file is scanned line by line and scanning stops at the first line holding
    a ``"data_name": "<...>.pt"`` entry. The caller uses the result as the
    cross-framework flag. Presumably ``.pt`` data names mean the benchmark was
    dumped from PyTorch; confirm against the dump writer.
    """
    pt_data_name = re.compile(r'"data_name":\s*"[^"]+\.pt"')
    with FileOpen(bench_json_path, 'r') as bench_file:
        return any(pt_data_name.search(line) for line in bench_file)
|
|
354
|
+
|
|
355
|
+
|
|
202
356
|
def ms_compare(input_param, output_path, **kwargs):
|
|
203
357
|
try:
|
|
204
358
|
stack_mode = kwargs.get('stack_mode', False)
|
|
@@ -206,14 +360,21 @@ def ms_compare(input_param, output_path, **kwargs):
|
|
|
206
360
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
207
361
|
cell_mapping = kwargs.get('cell_mapping', None)
|
|
208
362
|
api_mapping = kwargs.get('api_mapping', None)
|
|
209
|
-
|
|
210
|
-
|
|
363
|
+
data_mapping = kwargs.get('data_mapping', None)
|
|
364
|
+
layer_mapping = kwargs.get('layer_mapping', None)
|
|
365
|
+
suffix = kwargs.get('suffix', '')
|
|
366
|
+
|
|
367
|
+
set_dump_path(input_param)
|
|
368
|
+
dump_mode = get_dump_mode(input_param)
|
|
369
|
+
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
211
370
|
create_directory(output_path)
|
|
212
|
-
check_compare_param(input_param, output_path,
|
|
371
|
+
check_compare_param(input_param, output_path, dump_mode)
|
|
213
372
|
except (CompareException, FileCheckException) as error:
|
|
214
373
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
215
374
|
raise CompareException(error.code) from error
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
375
|
+
if layer_mapping:
|
|
376
|
+
data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
|
|
377
|
+
is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
|
|
378
|
+
ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
|
|
379
|
+
ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, suffix=suffix,
|
|
380
|
+
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)
|