mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -1,33 +1,73 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import os
|
|
2
17
|
import re
|
|
3
|
-
import
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
|
|
10
|
-
from msprobe.core.common.const import Const, CompareConst
|
|
11
|
-
from msprobe.core.common.log import logger
|
|
18
|
+
from collections import defaultdict
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
|
|
23
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
12
24
|
from msprobe.core.common.exceptions import FileCheckException
|
|
13
|
-
from msprobe.core.
|
|
14
|
-
from msprobe.core.
|
|
15
|
-
from msprobe.
|
|
16
|
-
|
|
25
|
+
from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
|
|
26
|
+
from msprobe.core.common.log import logger
|
|
27
|
+
from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
|
|
28
|
+
check_op_str_pattern_valid, get_dump_mode, set_dump_path
|
|
29
|
+
from msprobe.core.compare.acc_compare import Comparator, ModeConfig
|
|
30
|
+
from msprobe.core.compare.check import dtype_mapping
|
|
31
|
+
from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
|
|
32
|
+
from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list
|
|
17
33
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
34
|
+
|
|
35
|
+
class MappingConfig:
|
|
36
|
+
def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
|
|
21
37
|
self.cell_mapping = cell_mapping
|
|
22
38
|
self.api_mapping = api_mapping
|
|
23
39
|
self.data_mapping = data_mapping
|
|
24
|
-
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class MSComparator(Comparator):
|
|
43
|
+
"""
|
|
44
|
+
用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
|
|
45
|
+
cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
|
|
46
|
+
api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
|
|
47
|
+
data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
|
|
48
|
+
is_cross_framework: 是否跨框架。
|
|
49
|
+
"""
|
|
50
|
+
def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
|
|
51
|
+
super().__init__(mode_config)
|
|
52
|
+
self.frame_name = MSComparator.__name__
|
|
53
|
+
|
|
54
|
+
self.stack_mode = mode_config.stack_mode
|
|
55
|
+
self.auto_analyze = mode_config.auto_analyze
|
|
56
|
+
self.fuzzy_match = mode_config.fuzzy_match
|
|
57
|
+
self.dump_mode = mode_config.dump_mode
|
|
58
|
+
|
|
59
|
+
if mapping_config:
|
|
60
|
+
self.cell_mapping = mapping_config.cell_mapping
|
|
61
|
+
self.api_mapping = mapping_config.api_mapping
|
|
62
|
+
self.data_mapping = mapping_config.data_mapping
|
|
63
|
+
|
|
64
|
+
if self.data_mapping:
|
|
25
65
|
self.cross_frame = is_cross_framework
|
|
26
66
|
else:
|
|
27
|
-
self.cross_frame = cell_mapping is not None or api_mapping is not None
|
|
67
|
+
self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
|
|
28
68
|
self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
|
|
29
69
|
self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
|
|
30
|
-
if api_mapping is not None:
|
|
70
|
+
if self.api_mapping is not None:
|
|
31
71
|
self.ms_to_pt_mapping = self.load_internal_api()
|
|
32
72
|
|
|
33
73
|
if isinstance(self.data_mapping, str) or self.data_mapping is None:
|
|
@@ -38,9 +78,106 @@ class MSComparator(Comparator):
|
|
|
38
78
|
raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
|
|
39
79
|
f"{type(self.data_mapping)}")
|
|
40
80
|
|
|
81
|
+
def calc_accuracy(self, result_df, header):
|
|
82
|
+
condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
|
|
83
|
+
result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
|
|
84
|
+
result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
|
|
85
|
+
|
|
86
|
+
def calc_summary_diff(data_type: str):
|
|
87
|
+
def type_check(val):
|
|
88
|
+
check_series = pd.Series(False, index=val.index)
|
|
89
|
+
val_str = val.astype(str)
|
|
90
|
+
check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
|
|
91
|
+
return check_series
|
|
92
|
+
|
|
93
|
+
def get_number(val):
|
|
94
|
+
return pd.to_numeric(val.astype(str), errors='coerce')
|
|
95
|
+
|
|
96
|
+
ms_val = result_df['NPU ' + data_type]
|
|
97
|
+
pt_val = result_df['Bench ' + data_type]
|
|
98
|
+
diff_name = data_type.capitalize() + ' diff'
|
|
99
|
+
rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
|
|
100
|
+
condition_na = ~type_check(ms_val) | ~type_check(pt_val)
|
|
101
|
+
result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
|
|
102
|
+
result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
|
|
103
|
+
condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
|
|
104
|
+
condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
|
|
105
|
+
result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
|
|
106
|
+
condition_pt_zero = pt_val == 0
|
|
107
|
+
result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
|
|
108
|
+
condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
|
|
109
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
|
|
110
|
+
pt_val[condition_ref_err] * 100)
|
|
111
|
+
result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
|
|
112
|
+
.abs().astype(str) + '%')
|
|
113
|
+
magnitude = get_number(result_df[diff_name]).abs() / (
|
|
114
|
+
pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
|
|
115
|
+
return magnitude > CompareConst.MAGNITUDE
|
|
116
|
+
|
|
117
|
+
if self.dump_mode == Const.MD5:
|
|
118
|
+
condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
|
|
119
|
+
result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
|
|
120
|
+
result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
|
|
121
|
+
elif self.dump_mode == Const.SUMMARY:
|
|
122
|
+
warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
|
|
123
|
+
warning_flag = pd.DataFrame(warning_list).all()
|
|
124
|
+
result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
|
|
125
|
+
result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
|
|
126
|
+
result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
|
|
127
|
+
else:
|
|
128
|
+
fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
|
|
129
|
+
CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
|
|
130
|
+
CompareConst.ERROR_MESSAGE]
|
|
131
|
+
result_df.loc[~condition_no_bench, fill_cols] = ''
|
|
132
|
+
result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
|
|
133
|
+
return result_df[header]
|
|
134
|
+
|
|
135
|
+
def make_result_df(self, result):
|
|
136
|
+
header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
|
|
137
|
+
|
|
138
|
+
if self.stack_mode:
|
|
139
|
+
header.append(CompareConst.STACK)
|
|
140
|
+
if self.dump_mode == Const.ALL:
|
|
141
|
+
header.append(CompareConst.DATA_NAME)
|
|
142
|
+
result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
|
|
143
|
+
'op_name_y': CompareConst.BENCH_NAME,
|
|
144
|
+
'dtype_x': CompareConst.NPU_DTYPE,
|
|
145
|
+
'dtype_y': CompareConst.BENCH_DTYPE,
|
|
146
|
+
'shape_x': CompareConst.NPU_SHAPE,
|
|
147
|
+
'shape_y': CompareConst.BENCH_SHAPE,
|
|
148
|
+
'md5_x': CompareConst.NPU_MD5,
|
|
149
|
+
'md5_y': CompareConst.BENCH_MD5,
|
|
150
|
+
'data_name_x': CompareConst.DATA_NAME,
|
|
151
|
+
'stack_info_x': CompareConst.STACK}, inplace=True)
|
|
152
|
+
|
|
153
|
+
npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
|
|
154
|
+
bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
|
|
155
|
+
CompareConst.BENCH_NORM]
|
|
156
|
+
|
|
157
|
+
def set_summary(summary):
|
|
158
|
+
if summary == CompareConst.N_A:
|
|
159
|
+
return [CompareConst.N_A] * 4
|
|
160
|
+
summary_list = []
|
|
161
|
+
for i in summary:
|
|
162
|
+
if i is None:
|
|
163
|
+
summary_list.append(CompareConst.N_A)
|
|
164
|
+
elif str(i).lower() == 'nan':
|
|
165
|
+
summary_list.append(CompareConst.NAN)
|
|
166
|
+
else:
|
|
167
|
+
summary_list.append(i)
|
|
168
|
+
return summary_list
|
|
169
|
+
|
|
170
|
+
result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
|
|
171
|
+
result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
|
|
172
|
+
result_df = pd.DataFrame(columns=header)
|
|
173
|
+
for h in header:
|
|
174
|
+
if h in result.columns:
|
|
175
|
+
result_df[h] = result[h]
|
|
176
|
+
return self.calc_accuracy(result_df, header)
|
|
177
|
+
|
|
41
178
|
def load_internal_api(self):
|
|
42
179
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
43
|
-
yaml_path = os.path.join(cur_path,
|
|
180
|
+
yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
|
|
44
181
|
return load_yaml(yaml_path)
|
|
45
182
|
|
|
46
183
|
def load_mapping_file(self, mapping_file):
|
|
@@ -51,42 +188,23 @@ class MSComparator(Comparator):
|
|
|
51
188
|
return mapping_dict
|
|
52
189
|
|
|
53
190
|
def process_cell_mapping(self, npu_op_name):
|
|
54
|
-
|
|
191
|
+
if not npu_op_name:
|
|
192
|
+
return CompareConst.N_A
|
|
193
|
+
param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
|
|
194
|
+
if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
|
|
195
|
+
return CompareConst.N_A
|
|
196
|
+
npu_op_name = npu_op_name.replace("Cell", "Module", 1)
|
|
55
197
|
if self.cell_mapping_dict:
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
198
|
+
# get cell name & class name from op_name
|
|
199
|
+
# Cell.fc1.Dense.forward.0.input.0
|
|
200
|
+
cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
|
|
201
|
+
if cell_name in self.cell_mapping_dict:
|
|
202
|
+
npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
|
|
62
203
|
return npu_op_name
|
|
63
204
|
|
|
64
|
-
def check_op(self, npu_dict, bench_dict, fuzzy_match):
|
|
65
|
-
npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
|
|
66
|
-
npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
|
|
67
|
-
if self.cell_mapping is not None:
|
|
68
|
-
npu_op_name = self.process_cell_mapping(npu_op_name)
|
|
69
|
-
if self.api_mapping is not None:
|
|
70
|
-
npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
|
|
71
|
-
if isinstance(self.api_mapping, str):
|
|
72
|
-
npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
|
|
73
|
-
bench_dict_new)
|
|
74
|
-
if target_dict:
|
|
75
|
-
bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
|
|
76
|
-
npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
|
|
77
|
-
bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
|
|
78
|
-
struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
|
|
79
|
-
if not fuzzy_match:
|
|
80
|
-
return npu_op_name == bench_op_name and struct_match
|
|
81
|
-
is_match = True
|
|
82
|
-
try:
|
|
83
|
-
is_match = fuzzy_check_op(npu_op_name, bench_op_name)
|
|
84
|
-
except Exception as err:
|
|
85
|
-
logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
|
|
86
|
-
is_match = False
|
|
87
|
-
return is_match and struct_match
|
|
88
|
-
|
|
89
205
|
def read_npy_data(self, dir_path, file_name, load_pt_file=False):
|
|
206
|
+
if not file_name:
|
|
207
|
+
return None
|
|
90
208
|
data_path = os.path.join(dir_path, file_name)
|
|
91
209
|
if load_pt_file:
|
|
92
210
|
import torch
|
|
@@ -96,35 +214,23 @@ class MSComparator(Comparator):
|
|
|
96
214
|
data_value = data_value.to(torch.float32)
|
|
97
215
|
data_value = data_value.numpy()
|
|
98
216
|
else:
|
|
99
|
-
data_value = load_npy(data_path)
|
|
100
|
-
return data_value
|
|
217
|
+
data_value = load_npy(data_path)
|
|
218
|
+
return data_value
|
|
101
219
|
|
|
102
|
-
def
|
|
103
|
-
for idx, _ in enumerate(npu_op_name):
|
|
104
|
-
npu_op_name[idx] = npu_op_name[idx].replace(target, para)
|
|
105
|
-
return npu_op_name
|
|
106
|
-
|
|
107
|
-
def process_internal_api_mapping(self, npu_op_name, bench_op_name):
|
|
220
|
+
def process_internal_api_mapping(self, npu_op_name):
|
|
108
221
|
# get api name & class name from op_name
|
|
109
222
|
# Functional.addcmul.0.forward.input.0
|
|
110
|
-
|
|
111
|
-
ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
|
|
112
|
-
pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
|
|
223
|
+
ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
|
|
113
224
|
class_name = ms_api_name.split(Const.SEP)[0]
|
|
114
225
|
if class_name == "Mint":
|
|
115
|
-
return
|
|
226
|
+
return npu_op_name.replace("Mint", "Torch")
|
|
116
227
|
elif class_name == "MintFunctional":
|
|
117
|
-
return
|
|
118
|
-
elif self.ms_to_pt_mapping.get(ms_api_name)
|
|
119
|
-
return
|
|
228
|
+
return npu_op_name.replace("MintFunctional", "Functional")
|
|
229
|
+
elif self.ms_to_pt_mapping.get(ms_api_name):
|
|
230
|
+
return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
|
|
120
231
|
else:
|
|
121
232
|
return npu_op_name
|
|
122
|
-
|
|
123
|
-
def remove_element(self, op_name, struct, summary, idx):
|
|
124
|
-
del op_name[idx]
|
|
125
|
-
del struct[idx]
|
|
126
|
-
del summary[idx]
|
|
127
|
-
|
|
233
|
+
|
|
128
234
|
def get_api_name(self, api_list):
|
|
129
235
|
try:
|
|
130
236
|
api_name = api_list[0] + Const.SEP + api_list[1]
|
|
@@ -132,184 +238,147 @@ class MSComparator(Comparator):
|
|
|
132
238
|
logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
|
|
133
239
|
raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
|
|
134
240
|
return api_name
|
|
135
|
-
|
|
136
|
-
def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
|
|
137
|
-
"""
|
|
138
|
-
Transform user mapping API based on new NPU and benchmark dictionaries.
|
|
139
|
-
Parameters:
|
|
140
|
-
new_npu_dict (dict): New NPU operation dictionary.
|
|
141
|
-
new_bench_dict (dict): New benchmark operation dictionary.
|
|
142
|
-
Returns:
|
|
143
|
-
tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
|
|
144
|
-
"""
|
|
145
|
-
npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
|
|
146
|
-
npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
|
|
147
|
-
bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
|
|
148
|
-
npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
149
|
-
bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
|
|
150
|
-
npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
|
|
151
|
-
npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
|
|
152
|
-
npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
|
|
153
|
-
ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
|
|
154
|
-
ms_api_name = self.get_api_name(ms_api_list)
|
|
155
|
-
pt_api_name = self.get_api_name(pt_api_list)
|
|
156
|
-
target_dict = {}
|
|
157
|
-
for api_dict in self.api_mapping_dict:
|
|
158
|
-
if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
|
|
159
|
-
ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
|
|
160
|
-
ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
|
|
161
|
-
if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
|
|
162
|
-
logger.warning("The user-defined mapping table is incorrect,\
|
|
163
|
-
make sure that the number of parameters is equal")
|
|
164
|
-
break
|
|
165
|
-
ms_out_list = api_dict.get("ms_output", [])
|
|
166
|
-
for idx in reversed(range(npu_out_len)):
|
|
167
|
-
if idx not in ms_out_list:
|
|
168
|
-
del npu_struct_out[idx]
|
|
169
|
-
if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
|
|
170
|
-
del npu_summary[idx + npu_in_len]
|
|
171
|
-
del npu_op_name[idx + npu_in_len]
|
|
172
|
-
pt_out_list = api_dict.get("pt_output", [])
|
|
173
|
-
for idx in reversed(range(bench_out_len)):
|
|
174
|
-
if idx not in pt_out_list:
|
|
175
|
-
del bench_struct_out[idx]
|
|
176
|
-
if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
|
|
177
|
-
del bench_summary[idx + bench_in_len]
|
|
178
|
-
del bench_op_name[idx + bench_in_len]
|
|
179
|
-
ms_para_list = api_dict.get("ms_args", [])
|
|
180
|
-
for idx in reversed(range(npu_in_len)):
|
|
181
|
-
if idx not in ms_para_list:
|
|
182
|
-
self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
|
|
183
|
-
pt_para_list = api_dict.get("pt_args", [])
|
|
184
|
-
for idx in reversed(range(bench_in_len)):
|
|
185
|
-
if idx not in pt_para_list:
|
|
186
|
-
self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
|
|
187
|
-
npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
|
|
188
|
-
npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
|
|
189
|
-
target_dict = api_dict
|
|
190
|
-
break
|
|
191
|
-
if target_dict:
|
|
192
|
-
new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
|
|
193
|
-
CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
|
|
194
|
-
new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
|
|
195
|
-
CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
|
|
196
|
-
return new_npu_dict, new_bench_dict, target_dict
|
|
197
|
-
|
|
198
|
-
def para_sequence_update(self, npu_op_name, bench_op_name):
|
|
199
|
-
for idx, _ in enumerate(npu_op_name):
|
|
200
|
-
bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
|
|
201
|
-
if len(bench_op_name_list) != 0:
|
|
202
|
-
npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
|
|
203
|
-
return npu_op_name
|
|
204
241
|
|
|
205
|
-
def
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
242
|
+
def compare_process(self, file_lists):
|
|
243
|
+
npu_json_path, bench_json_path, stack_json_path = file_lists
|
|
244
|
+
npu_json_data = load_json(npu_json_path)
|
|
245
|
+
bench_json_data = load_json(bench_json_path)
|
|
246
|
+
stack_json_data = load_json(stack_json_path) if self.stack_mode else None
|
|
247
|
+
|
|
248
|
+
npu_df = self.gen_data_df(npu_json_data, stack_json_data)
|
|
249
|
+
bench_df = self.gen_data_df(bench_json_data, stack_json_data)
|
|
250
|
+
if self.cell_mapping:
|
|
251
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
|
|
252
|
+
elif self.api_mapping:
|
|
253
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
|
|
254
|
+
if isinstance(self.api_mapping, str):
|
|
255
|
+
self.modify_compare_data_with_user_mapping(npu_df, bench_df)
|
|
256
|
+
else:
|
|
257
|
+
npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
|
|
258
|
+
npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
259
|
+
bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
|
|
260
|
+
npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
|
|
261
|
+
bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
|
|
262
|
+
bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
|
|
263
|
+
match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
|
|
264
|
+
how='outer')
|
|
265
|
+
match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
|
|
266
|
+
|
|
267
|
+
def gen_dtype_condition():
|
|
268
|
+
npu_dtype = match_result['dtype_x']
|
|
269
|
+
bench_dtype = match_result['dtype_y']
|
|
270
|
+
if self.cross_frame:
|
|
271
|
+
npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
|
|
272
|
+
return ((npu_dtype == bench_dtype) |
|
|
273
|
+
((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
|
|
274
|
+
((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
|
|
275
|
+
((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
|
|
276
|
+
((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
|
|
277
|
+
((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
|
|
278
|
+
((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
|
|
279
|
+
((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
|
|
280
|
+
((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
|
|
281
|
+
|
|
282
|
+
match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
|
|
283
|
+
return self.make_result_df(match_result)
|
|
284
|
+
|
|
285
|
+
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
|
|
286
|
+
def get_api_indices_dict(op_name_df):
|
|
287
|
+
api_indices_dict = defaultdict(list)
|
|
288
|
+
for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
|
|
289
|
+
api = self.get_api_name(name.split(Const.SEP))
|
|
290
|
+
api_indices_dict[api].append(op_index)
|
|
291
|
+
return api_indices_dict
|
|
292
|
+
|
|
293
|
+
ms_api_indices_dict = get_api_indices_dict(npu_df)
|
|
294
|
+
pt_api_indices_dict = get_api_indices_dict(bench_df)
|
|
295
|
+
|
|
296
|
+
def gen_input_compare_key(pattern, term):
|
|
297
|
+
flag = True
|
|
298
|
+
for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
|
|
299
|
+
if op_name.split(pattern)[1].startswith(str(prefix)):
|
|
300
|
+
npu_df.loc[index, CompareConst.COMPARE_KEY] = (
|
|
301
|
+
op_name.replace(pattern + str(prefix),
|
|
302
|
+
pattern + str(mapping_dict.get(f'pt_{term}')[i])))
|
|
303
|
+
flag = False
|
|
304
|
+
return flag
|
|
305
|
+
|
|
306
|
+
for mapping_dict in self.api_mapping_dict:
|
|
307
|
+
keys_to_compare = [
|
|
308
|
+
('ms_args', 'pt_args'),
|
|
309
|
+
('ms_output', 'pt_output'),
|
|
310
|
+
('ms_parameters', 'pt_parameters'),
|
|
311
|
+
('ms_parameters_grad', 'pt_parameters_grad'),
|
|
312
|
+
]
|
|
313
|
+
if not all(len(mapping_dict.get(k1, [])) == len(mapping_dict.get(k2, [])) for k1, k2 in keys_to_compare):
|
|
314
|
+
logger.warning('The user-defined mapping table is incorrect,\
|
|
315
|
+
make sure that the number of parameters is equal')
|
|
316
|
+
continue
|
|
317
|
+
|
|
318
|
+
ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
|
|
319
|
+
if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
|
|
320
|
+
continue
|
|
321
|
+
for index in ms_api_indices_dict.get(ms_api):
|
|
322
|
+
op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
|
|
323
|
+
if CompareConst.INPUT_PATTERN in op_name:
|
|
324
|
+
is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
|
|
325
|
+
elif CompareConst.KWARGS_PATTERN in op_name:
|
|
326
|
+
is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
|
|
327
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
328
|
+
is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
|
|
329
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
330
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_PATTERN, 'parameters')
|
|
331
|
+
elif CompareConst.PARAMS_GRAD_PATTERN in op_name:
|
|
332
|
+
is_abandoned = gen_input_compare_key(CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad')
|
|
333
|
+
else:
|
|
334
|
+
logger.error(f'Excepted op_name: {op_name}')
|
|
335
|
+
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
336
|
+
if is_abandoned:
|
|
337
|
+
npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
|
|
338
|
+
|
|
339
|
+
def gen_data_df(self, data_json, stack_json_data):
|
|
340
|
+
result = {
|
|
341
|
+
CompareConst.OP_NAME: [],
|
|
342
|
+
Const.DTYPE: [],
|
|
343
|
+
Const.SHAPE: [],
|
|
344
|
+
Const.SUMMARY: [],
|
|
345
|
+
'stack_info': []
|
|
346
|
+
}
|
|
347
|
+
if self.dump_mode == Const.ALL:
|
|
348
|
+
result['data_name'] = []
|
|
349
|
+
elif self.dump_mode == Const.MD5:
|
|
350
|
+
result[Const.MD5] = []
|
|
351
|
+
for data_name in data_json['data']:
|
|
352
|
+
check_op_str_pattern_valid(data_name)
|
|
353
|
+
merge_list = self.gen_merge_list(data_json, data_name, stack_json_data)
|
|
354
|
+
if not merge_list:
|
|
271
355
|
continue
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
for
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
backward_data = []
|
|
300
|
-
mapping_list = sort_by_execution_sequence(npu_data, bench_data, mapping_list, Const.BACKWARD)
|
|
301
|
-
for map_value in mapping_list:
|
|
302
|
-
npu_forward_inputs, npu_backward_outputs = generate_kernel_data(map_value[0], npu_data, "backward")
|
|
303
|
-
bench_forward_inputs, bench_backward_outputs = generate_kernel_data(map_value[1], bench_data, "backward")
|
|
304
|
-
inputs_zip = list(zip_longest(npu_forward_inputs, bench_forward_inputs))
|
|
305
|
-
outputs_zip = list(zip_longest(npu_backward_outputs, bench_backward_outputs))
|
|
306
|
-
backward_data.extend(inputs_zip)
|
|
307
|
-
backward_data.extend(outputs_zip)
|
|
308
|
-
|
|
309
|
-
kernel_data = forward_data + backward_data
|
|
310
|
-
result = {key: value for key, value in kernel_data if key is not None}
|
|
311
|
-
|
|
312
|
-
return result
|
|
356
|
+
|
|
357
|
+
op_name_list = merge_list.get(CompareConst.OP_NAME)
|
|
358
|
+
summary_list = merge_list.get(Const.SUMMARY)
|
|
359
|
+
data_name_list = merge_list.get('data_name')
|
|
360
|
+
op_name_reorder, summary_reorder, data_name_reorder = reorder_op_x_list(op_name_list,
|
|
361
|
+
summary_list,
|
|
362
|
+
data_name_list)
|
|
363
|
+
for op_name in op_name_reorder:
|
|
364
|
+
result[CompareConst.OP_NAME].append(op_name)
|
|
365
|
+
if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
|
|
366
|
+
struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
|
|
367
|
+
elif CompareConst.OUTPUT_PATTERN in op_name:
|
|
368
|
+
struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
|
|
369
|
+
elif CompareConst.PARAMS_PATTERN in op_name:
|
|
370
|
+
struct = merge_list[CompareConst.PARAMS_STRUCT].pop(0)
|
|
371
|
+
else:
|
|
372
|
+
struct = merge_list[CompareConst.PARAMS_GRAD_STRUCT].pop(0)
|
|
373
|
+
result[Const.DTYPE].append(struct[0])
|
|
374
|
+
result[Const.SHAPE].append(struct[1])
|
|
375
|
+
if self.dump_mode == Const.MD5:
|
|
376
|
+
result[Const.MD5].append(struct[2])
|
|
377
|
+
result[Const.SUMMARY].append(summary_reorder.pop(0))
|
|
378
|
+
result['stack_info'].append(merge_list['stack_info'][0] if self.stack_mode else None)
|
|
379
|
+
if self.dump_mode == Const.ALL:
|
|
380
|
+
result['data_name'].append(data_name_reorder.pop(0))
|
|
381
|
+
return pd.DataFrame(result)
|
|
313
382
|
|
|
314
383
|
|
|
315
384
|
def check_cross_framework(bench_json_path):
|
|
@@ -323,35 +392,31 @@ def check_cross_framework(bench_json_path):
|
|
|
323
392
|
|
|
324
393
|
def ms_compare(input_param, output_path, **kwargs):
|
|
325
394
|
try:
|
|
326
|
-
stack_mode = kwargs.get('stack_mode', False)
|
|
327
395
|
auto_analyze = kwargs.get('auto_analyze', True)
|
|
328
396
|
fuzzy_match = kwargs.get('fuzzy_match', False)
|
|
329
397
|
cell_mapping = kwargs.get('cell_mapping', None)
|
|
330
398
|
api_mapping = kwargs.get('api_mapping', None)
|
|
331
399
|
data_mapping = kwargs.get('data_mapping', None)
|
|
332
400
|
layer_mapping = kwargs.get('layer_mapping', None)
|
|
401
|
+
suffix = kwargs.get('suffix', '')
|
|
333
402
|
|
|
334
|
-
|
|
403
|
+
set_dump_path(input_param)
|
|
404
|
+
dump_mode = get_dump_mode(input_param)
|
|
405
|
+
if 'stack_json_path' in input_param:
|
|
406
|
+
stack_mode = kwargs.get('stack_mode', False)
|
|
407
|
+
else:
|
|
408
|
+
stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
|
|
335
409
|
check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
|
|
336
410
|
create_directory(output_path)
|
|
337
|
-
check_compare_param(input_param, output_path,
|
|
411
|
+
check_compare_param(input_param, output_path, dump_mode, stack_mode)
|
|
338
412
|
except (CompareException, FileCheckException) as error:
|
|
339
413
|
logger.error('Compare failed. Please check the arguments and do it again!')
|
|
340
414
|
raise CompareException(error.code) from error
|
|
341
415
|
if layer_mapping:
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
data_mapping_name = add_time_with_yaml(f"data_mapping")
|
|
351
|
-
data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
|
|
352
|
-
save_yaml(data_mapping_path, data_mapping)
|
|
353
|
-
is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
|
|
354
|
-
ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
|
|
355
|
-
ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
|
|
356
|
-
auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
|
|
357
|
-
md5_compare=md5_compare)
|
|
416
|
+
data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
|
|
417
|
+
|
|
418
|
+
mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
|
|
419
|
+
mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
|
|
420
|
+
is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
|
|
421
|
+
ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework)
|
|
422
|
+
ms_comparator.compare_core(input_param, output_path, suffix=suffix)
|