mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import re
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from typing import ClassVar, Dict, List, Optional, Tuple
|
|
21
|
+
|
|
22
|
+
import yaml
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.core.common.file_utils import save_yaml
|
|
25
|
+
from msprobe.core.common.log import logger
|
|
26
|
+
from msprobe.core.common.utils import CompareException, add_time_with_yaml
|
|
27
|
+
from msprobe.core.compare.layer_mapping.postprocess_pass import postprocess_pass
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class DumpDataItem:
    """Scope information for a single entry of a dump.json "data" section.

    ``set()`` fills the fields: api type/name and state come from the data
    name, the layer scope from the construct info, and the framework / user
    function scopes from the stack info. ``full_scope`` joins them into one
    dotted path (see ``set_full_scope``).
    """

    framework: str
    data_name: Optional[str] = None
    api_type: Optional[str] = None
    api_name: Optional[str] = None
    type_name: Optional[str] = None
    full_scope: str = ""
    layer_scope: str = ""
    stack_scope: str = ""
    frame_stack_scope: str = ""
    user_stack_scope: str = ""
    construct_scope: str = ""
    scope_direction: Optional[str] = None
    # Assigned from a split construct string in set_layer_scope, so it is a str.
    scope_id: Optional[str] = None
    state: str = ""

    # Class-level constants; ClassVar keeps them out of the generated __init__.
    # Keep this above ``def set`` so ``set`` here still names the builtin type.
    layernames: ClassVar[set] = {Const.CELL, Const.MODULE}
    # framework -> (start regex, end regex) delimiting the interesting stack region
    framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
        Const.MS_FRAMEWORK: ("Template", "construct"),
        Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
    }

    @staticmethod
    def check_stack_valid(stack_info):
        """Raise CompareException unless stack_info is None or a list of str."""
        if stack_info is not None:
            if not isinstance(stack_info, list):
                logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            for stack in stack_info:
                if not isinstance(stack, str):
                    logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
                    raise CompareException(CompareException.INVALID_DATA_ERROR)

    def set(self, data_name: str, construct_info: Optional[str], stack_info: Optional[List[str]]) -> None:
        """Populate every scope field from the data name, construct info and stack info.

        Args:
            data_name: key of the entry in the dump "data" section.
            construct_info: construct entry for the data name, or None if absent.
            stack_info: list of stack lines for the data name, or None if absent.

        Raises:
            CompareException: if any of the inputs is malformed.
        """
        self.set_name(data_name)
        self.set_layer_scope(construct_info)
        self.set_stack_scope(stack_info)
        self.set_full_scope()

    def set_name(self, data_name: str) -> None:
        """Derive api_type, api_name, type_name and state from data_name.

        Raises:
            CompareException: if data_name has fewer fields than required.
        """
        self.data_name = data_name
        data_name_list = data_name.split(Const.SEP)
        if not data_name_list or len(data_name_list) < abs(Const.LAYER_NAME_INDEX):
            logger.error(
                f"The dump data does not comply with the format specification and "
                f"must contain no less than four fields. "
                f"The current data is {data_name}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)

        # Parameter gradients use their own field layout.
        if data_name_list[Const.LAST_INDEX] == Const.PARAMS_GRAD:
            self.api_type = Const.PARAMS_GRAD
            self.api_name = data_name_list[Const.PARAMS_GRAD_NAME_INDEX]
            self.type_name = data_name_list[Const.PARAMS_GRAD_TYPE_NAME_INDEX]
            self.state = Const.PARAMS_GRAD
            return

        self.api_type = data_name_list[Const.API_TYPE_INDEX]
        self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
        if self.api_type in self.layernames:
            # Cell/Module entries carry the layer name and direction at fixed positions.
            self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
            self.state = data_name_list[Const.SCOPE_DIRECTION_INDEX]
        else:
            self.api_name = self.type_name
            self.state = data_name_list[Const.LAST_INDEX]

    def set_layer_scope(self, construct_info: Optional[str]) -> None:
        """Compute layer_scope, plus scope_id / scope_direction when construct_info is given.

        Raises:
            CompareException: if construct_info has fewer fields than required.
        """
        self.construct_scope = construct_info
        if self.api_type in self.layernames:
            # remove api name
            data_list = self.data_name.split(Const.SEP)
            data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
        elif self.api_type == Const.PARAMS_GRAD:
            data_list = self.data_name.split(Const.SEP)
        elif construct_info:
            data_list = construct_info.split(Const.SEP)
        else:
            data_list = []

        if data_list:
            self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
        else:
            self.layer_scope = Const.TOP_LAYER
        if construct_info:
            construct_list = construct_info.split(Const.SEP)
            if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                logger.error(
                    f"The construct data does not comply with the format specification and "
                    f"must contain no less than four fields. "
                    f"The current data is {construct_info}"
                )
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            self.scope_id = construct_list[Const.SCOPE_ID_INDEX]
            self.scope_direction = construct_list[Const.SCOPE_DIRECTION_INDEX]

    def set_stack_scope(self, stack_info: Optional[List[str]]) -> None:
        """Fill frame_stack_scope and user_stack_scope from the stack lines.

        Raises:
            CompareException: if the framework is unsupported or stack_info
                is not a list of str.
        """
        # Cell/Module has no stack info
        if self.api_type in self.layernames:
            return

        if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
            return

        signs = self.framework2stack_sign.get(self.framework)
        if signs is None:
            # Previously this surfaced as an opaque TypeError from tuple unpacking.
            logger.error(f"Unsupported framework {self.framework}, cannot locate the stack scope.")
            raise CompareException(CompareException.INVALID_DATA_ERROR)
        start_sign, end_sign = signs
        self.check_stack_valid(stack_info)
        start_pos, end_pos = find_regard_scope(stack_info, start_sign, end_sign)
        # Keep only the lines strictly between the start and end markers.
        regard_scope = stack_info[start_pos + 1:end_pos]
        frame_func_stack_list, user_func_stack_list = find_stack_func_list(regard_scope)
        self.frame_stack_scope = Const.SEP.join(frame_func_stack_list)
        self.user_stack_scope = Const.SEP.join(user_func_stack_list)

    def set_full_scope(self, use_user_func_scope=False, use_frame_func_scope=True) -> None:
        """Join layer scope, the optional user/framework scopes and api_name into full_scope."""
        scope_list = [self.layer_scope]
        if use_user_func_scope and self.user_stack_scope:
            scope_list.append(self.user_stack_scope)
        if use_frame_func_scope and self.frame_stack_scope:
            scope_list.append(self.frame_stack_scope)
        scope_list.append(self.api_name)
        self.full_scope = Const.SEP.join(scope_list)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the stack region delimited by two regex markers.

    Every line matching ``start_sign`` moves the start forward. The first
    later line matching ``end_sign`` (and not ``start_sign``) closes the
    region.

    Returns:
        ``(start, end)`` indices: start is -1 when no start marker was seen,
        and end is ``len(lines)`` when no closing marker followed it.
    """
    begin = -1
    finish = len(lines)
    for position, text in enumerate(lines):
        if re.search(start_sign, text):
            begin = position
            continue
        if begin >= 0 and re.search(end_sign, text):
            finish = position
            break
    return begin, finish
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def find_stack_func_list(lines, record_user=True):
    """Split stack lines into the framework entrance function and user functions.

    Lines are scanned in order. While only framework lines (file containing an
    entry of ``Const.FRAME_FILE_LIST``) have been seen, the latest one is kept
    as the framework entrance. The first non-framework line ends that run.
    From then on every line is treated as a user line and, if
    ``record_user`` is true, collected.

    Returns:
        ``(frame_func, user_func)``: the function names extracted by
        ``get_stack_in_lines`` for the entrance line and for the user lines.
    """
    user_stack = []
    frame_stack = None
    no_entrance = True
    for line in lines:
        file_ele = line.split(Const.COMMA)[Const.STACK_FILE_INDEX]
        # if framework func line and no framework entrance found yet
        if no_entrance and any(frame_file in file_ele for frame_file in Const.FRAME_FILE_LIST):
            frame_stack = line  # keep the last line of the leading framework run
        else:
            if record_user:
                user_stack.append(line)
            no_entrance = False

    # get_stack_in_lines skips falsy entries, so frame_stack=None yields [].
    frame_func = get_stack_in_lines([frame_stack])
    user_func = get_stack_in_lines(user_stack)
    return frame_func, user_func
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def get_stack_in_lines(simplified: List[str]):
    """Extract function names from stack lines, outermost first.

    Empty entries are ignored, as are lines whose file field matches
    ``Const.FILE_SKIP_LIST`` or whose function field matches
    ``Const.FUNC_SKIP_LIST``. The collected names are returned in reverse
    input order.
    """
    func_names = []
    for entry in simplified or []:
        if not entry:
            continue

        fields = entry.split(Const.COMMA)
        if any(skip in fields[Const.STACK_FILE_INDEX] for skip in Const.FILE_SKIP_LIST):
            continue

        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(skip in func_field for skip in Const.FUNC_SKIP_LIST):
            continue

        func_names.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])

    func_names.reverse()
    return func_names
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def dumpdata_representer(dumper, data):
    """yaml representer: dump an object's attributes as a mapping, minus data_name.

    A deep copy is taken so the emitted mapping shares no mutable state with
    ``data``. A missing ``data_name`` attribute raises KeyError.
    """
    snapshot = deepcopy(vars(data))
    del snapshot["data_name"]
    return dumper.represent_dict(snapshot)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def get_dump_data_items(dump, stack, construct, framework, output_path=None):
    """Build a DumpDataItem for every entry of ``dump["data"]``.

    Args:
        dump: parsed dump.json content; only its "data" mapping is read.
        stack: data name -> stack info; entries may be missing.
        construct: data name -> construct info; entries may be missing.
        framework: framework identifier passed to every DumpDataItem.
        output_path: when truthy, the items keyed by data name are also saved
            as yaml in this directory (file name from add_time_with_yaml).

    Returns:
        The items in dump order after postprocess_pass, or an empty list when
        either stack or construct is empty.
    """
    if not (stack and construct):
        return []

    items_by_name = {}
    ordered_items = []
    for name in dump.get("data", {}):
        stack_entry = stack.get(name, None)
        construct_entry = construct.get(name, None)
        item = DumpDataItem(framework)
        item.set(name, construct_entry, stack_entry)
        items_by_name[name] = item
        ordered_items.append(item)

    postprocess_pass(ordered_items, items_by_name)

    if output_path:
        # Register how DumpDataItem is serialised before saving the mapping.
        yaml.add_representer(DumpDataItem, dumpdata_representer)
        yaml_name = add_time_with_yaml(f"{framework}_data")
        save_yaml(os.path.join(os.path.realpath(output_path), yaml_name), items_by_name)
    return ordered_items
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
from collections import defaultdict
|
|
18
|
+
|
|
19
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
20
|
+
from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
|
|
21
|
+
from msprobe.core.common.utils import (add_time_with_yaml,
|
|
22
|
+
detect_framework_by_dump_json,
|
|
23
|
+
get_stack_construct_by_dump_json_path)
|
|
24
|
+
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
25
|
+
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LayerTrie:
    """Prefix tree of dump data items keyed by the components of their full scope.

    Each node is one scope component. An item is stored on the node reached by
    walking its full scope, grouped by the item's state.
    """

    def __init__(self, type_name, framework=None):
        self.type_name = type_name
        # state -> list of data items stored on this node
        self.data_items = defaultdict(list)
        # scope component -> child LayerTrie
        self.children = {}
        self.framework = framework

    def __repr__(self):
        data_nums = [{k: len(v)} for k, v in self.data_items.items()]
        return f"Layer(type_name={self.type_name}, data_number={data_nums})"

    def get(self, name):
        """Return the child node called ``name``, or None if there is none."""
        return self.children.get(name)

    def insert(self, data_item):
        """Store ``data_item`` under its full scope, creating nodes as needed.

        The first scope component (the root itself) is skipped. The final
        node takes the item's type_name.
        """
        parts = data_item.full_scope.split(Const.SEP)
        node = self
        scope_name_list = parts[Const.RIGHT_MOVE_INDEX:]

        for name in scope_name_list:
            if name not in node.children:
                node.children[name] = LayerTrie(name, data_item.framework)
            node = node.children[name]
        node.data_items[data_item.state].append(data_item)
        node.type_name = data_item.type_name

    def query_data(self, scope, state, index, default_value=None):
        """Return the ``index``-th item with ``state`` stored under ``scope``.

        The scope is walked exactly like ``insert`` walks it. ``default_value``
        is returned when the scope path does not exist or when index is
        negative or out of range. The trie is never modified by a query.
        """
        node = self
        # Skip the root component the same way insert() does.
        for name in scope.split(Const.SEP)[Const.RIGHT_MOVE_INDEX:]:
            node = node.children.get(name)
            if node is None:
                return default_value
        # .get() avoids creating an empty state entry in the defaultdict on a miss.
        items = node.data_items.get(state, [])
        # A negative index (e.g. a -1 "not found" marker) must not wrap around.
        if index < 0 or index >= len(items):
            return default_value
        return items[index]

    def save_to_yaml(self, output_path):
        """Write the whole tree as nested dicts to a yaml file in ``output_path``."""
        result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
        file_name = add_time_with_yaml(f"{self.framework}_tree")
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        save_yaml(file_path, result)

    def convert_to_dict(self, node):
        """Recursively convert ``node`` into a dict of data names and child sub-dicts."""
        result = {}
        result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
        for child_key, child_node in node.children.items():
            key = f"{child_key} @ {child_node}"
            result[key] = self.convert_to_dict(child_node)
        return result
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def convert_scope(layer_trie, data_item, mapping=None):
    """
    Translate the scope of an NPU data item into the corresponding bench-side scope.

    Walks ``data_item.full_scope`` through ``layer_trie`` one segment at a time. At each step the
    type name of the current trie node selects a rule table from ``mapping`` (as produced by
    ``preprocess_layer_mapping``). The first rule whose origin equals the upcoming segment(s)
    appends its target to the new scope and consumes ``level`` segments. An identity rule is
    always tried last, so segments without a matching rule are kept unchanged.

    Args:
        layer_trie: Root LayerTrie built from the NPU dump that contains ``data_item``.
        data_item: Item whose ``full_scope``, ``state`` and ``data_name`` are read.
        mapping: Preprocessed mapping ``{type_name: {prefix: [(origin, target, level), ...]}}``.
            It is only read, never modified.

    Returns:
        tuple: ``(new_scope, state, index)``. ``index`` is the position of ``data_item`` among
        the items with the same state on its trie node, or -1 if it was not found there.
    """
    if not mapping:
        mapping = {}
    new_scope = Const.TOP_LAYER
    scope_list = data_item.full_scope.split(Const.SEP)
    cur_node = layer_trie

    idx = 0
    while idx < len(scope_list) - 1:
        child_name = scope_list[idx + 1]
        prefix_mapping = mapping.get(cur_node.type_name, {})
        # Copy before adding the identity fallback: appending to the list stored in ``mapping``
        # would mutate the shared rule table and make it grow on every call.
        candidates = list(prefix_mapping.get(child_name, []))
        candidates.append((child_name, child_name, 1))
        step = 1
        for origin, target, level in candidates:
            if Const.SEP.join(scope_list[idx + 1: idx + level + 1]) == origin:
                new_scope = new_scope + Const.SEP + target
                step = level
                break
        # Advance through the trie by as many segments as the chosen rule consumed.
        for _ in range(step):
            cur_node = cur_node.get(scope_list[idx + 1])
            idx += 1

    state = data_item.state
    index = -1
    for position, child in enumerate(cur_node.data_items[state]):
        if data_item.data_name == child.data_name:
            index = position
    return new_scope, state, index
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_data_items_and_tree(dump_json_path, output_path):
    """
    Parse a dump JSON into data items and build the LayerTrie that indexes them.

    Args:
        dump_json_path: Path of the dump JSON; the framework and the stack/construct
            information are derived from it as well.
        output_path: Passed to ``get_dump_data_items``; if truthy, the trie is also saved
            there with ``LayerTrie.save_to_yaml``.

    Returns:
        tuple: ``(dump_data_items, root)``, i.e. the items from ``get_dump_data_items`` and the trie root.
    """
    framework = detect_framework_by_dump_json(dump_json_path)
    stack, construct = get_stack_construct_by_dump_json_path(dump_json_path)
    data_items = get_dump_data_items(load_json(dump_json_path), stack, construct, framework, output_path)

    trie_root = LayerTrie(Const.TOP_LAYER, framework)
    for item in data_items:
        trie_root.insert(item)

    if output_path:
        trie_root.save_to_yaml(output_path)
    return data_items, trie_root
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def convert_data_item(npu_tree, bench_tree, npu_data_item, mapping):
    """
    Find the bench-side data item corresponding to ``npu_data_item``.

    The NPU item's scope is translated with ``convert_scope`` and looked up in ``bench_tree``
    via ``query_data``. The result is ``None`` when nothing exists at that scope/state/index.
    """
    scope, state, position = convert_scope(npu_tree, npu_data_item, mapping)
    return bench_tree.query_data(scope, state, position)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def update_keys_in_place(d):
    """
    Rename the legacy top-level key ``Const.CELL`` to ``Const.TOP_LAYER`` in ``d``, in place.

    This keeps mapping files from older versions, which used 'Cell' as the top layer name,
    compatible with the new 'TopLayer' name. A ``Const.CELL`` entry whose value is ``None``
    is removed and not re-added.
    """
    legacy_value = d.pop(Const.CELL, None)
    if legacy_value is None:
        return
    d[Const.TOP_LAYER] = legacy_value
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def preprocess_layer_mapping(mapping):
    """
    Group layer-mapping rules by their first scope segment so convert_scope can look them up.

    The legacy ``Cell`` top-layer key is renamed via ``update_keys_in_place`` on a shallow copy,
    so the caller's dict is left untouched. Rule keys are split on ``Const.SEP``, the same
    separator ``convert_scope`` uses when it joins scope segments for comparison.

    before:
    {'A': {'a.b.c': 'new_c',
           'a.demo': 'new_demo',
           'z': 'new_z',
           'd.e': 'e'}}
    after:
    {'A': {'a': [('a.b.c', 'new_c', 3), ('a.demo', 'new_demo', 2)],
           'z': [('z', 'new_z', 1)],
           'd': [('d.e', 'e', 2)]}}

    Args:
        mapping: Raw mapping ``{type_name: {origin_scope: target_name}}``. ``None`` or empty
            yields ``{}``, and a type entry with no rules (e.g. a bare ``A:`` in YAML) is
            treated as empty. Non-string origin keys (YAML may load ``0:`` as an int) are
            converted with ``str``.

    Returns:
        dict: ``{type_name: {prefix: [(origin, target, level), ...]}}``. Each rule list is
        sorted by ``level`` descending (longest rule first), and ties keep their original order.
    """
    if not mapping:
        return {}
    mapping = dict(mapping)
    update_keys_in_place(mapping)
    final_mapping = {}

    for type_name, name_map in mapping.items():
        prefix_rules = {}
        for key, value in (name_map or {}).items():
            origin = str(key)
            key_list = origin.split(Const.SEP)
            prefix_rules.setdefault(key_list[0], []).append((origin, value, len(key_list)))
        # Longest rules first so the most specific match wins in convert_scope
        # (sort is stable, also with reverse=True).
        for rules in prefix_rules.values():
            rules.sort(key=lambda rule: rule[-1], reverse=True)
        final_mapping[type_name] = prefix_rules

    return final_mapping
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def convert_data_items(npu_tree, bench_tree, npu_data_items, mapping):
    """
    Map every NPU data item name to the name of its matching bench data item.

    ``mapping`` is the raw layer mapping and is preprocessed once with
    ``preprocess_layer_mapping``. Items without a bench counterpart map to ``CompareConst.N_A``.

    Returns:
        dict: ``{npu_data_name: bench_data_name}``.
    """
    rules = preprocess_layer_mapping(mapping)

    def _bench_name(npu_item):
        # Name of the bench item matched to npu_item, or N/A when there is none.
        bench_item = convert_data_item(npu_tree, bench_tree, npu_item, rules)
        return CompareConst.N_A if not bench_item else bench_item.data_name

    return {item.data_name: _bench_name(item) for item in npu_data_items}
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path=None, output_path=None):
    """
    Build the NPU-to-bench API name mapping for two dump files.

    Both dumps are parsed into data items and trees, using ``output_path`` for those steps.
    The layer mapping YAML is loaded only when ``layer_mapping_path`` is a ``str``; otherwise
    no mapping rules are applied. When ``output_path`` is set, the result is also saved there
    as an ``api_mapping`` YAML file.

    Returns:
        dict: ``{npu_api_name: bench_api_name}``, with ``CompareConst.N_A`` for unmatched names.
    """
    npu_data_items, npu_root = get_data_items_and_tree(npu_json_path, output_path)
    bench_root = get_data_items_and_tree(bench_json_path, output_path)[1]
    layer_mapping = load_yaml(layer_mapping_path) if isinstance(layer_mapping_path, str) else {}
    api_mapping = convert_data_items(npu_root, bench_root, npu_data_items, layer_mapping)
    if not output_path:
        return api_mapping
    save_path = os.path.join(os.path.realpath(output_path), add_time_with_yaml("api_mapping"))
    save_yaml(save_path, api_mapping)
    return api_mapping
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path=None):
    """
    Expand an API-level mapping into a mapping between individual data entries.

    Each ``npu_op -> bench_op`` pair in ``api_mapping`` is processed; pairs with an empty NPU
    name are skipped. The full op names on both sides are read from the dumps' ``"data"``
    sections with ``read_op`` and reordered with ``reorder_op_name_list``. NPU entries are then
    paired with bench entries that share the same suffix after the op name. Unpaired NPU
    entries map to ``CompareConst.N_A``. When ``output_path`` is set, the result is also saved
    there as a ``data_mapping`` YAML file.

    Returns:
        dict: ``{npu_full_op_name: bench_full_op_name}``.
    """
    def read_full_op_names(data, op_name):
        # 'full_op_name' of every entry that read_op parses out of data[op_name].
        return [parsed.get('full_op_name') for parsed in read_op(data.get(op_name, {}), op_name)]

    def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names):
        # Index bench entries by the text following the op name. A later bench entry with
        # the same suffix replaces an earlier one.
        bench_by_suffix = {name[len(bench_op_name):]: name for name in bench_full_op_names}
        return {
            name: bench_by_suffix.get(name[len(npu_op_name):], CompareConst.N_A)
            for name in npu_full_op_names
        }

    npu_data = load_json(npu_json_path).get("data", {})
    bench_data = load_json(bench_json_path).get("data", {})

    data_mapping = {}
    for npu_op_name, bench_op_name in api_mapping.items():
        if not npu_op_name:
            continue
        npu_names = read_full_op_names(npu_data, npu_op_name)
        bench_names = read_full_op_names(bench_data, bench_op_name)
        op_mapping = generate_op_data_mapping(npu_op_name, reorder_op_name_list(npu_names),
                                              bench_op_name, reorder_op_name_list(bench_names))
        data_mapping.update(op_mapping)

    if output_path:
        save_path = os.path.join(os.path.realpath(output_path), add_time_with_yaml("data_mapping"))
        save_yaml(save_path, data_mapping)
    return data_mapping
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def generate_data_mapping_by_layer_mapping(input_param, layer_mapping_path=None, output_path=None):
    """
    Produce the data-level NPU-to-bench mapping for the dumps named in ``input_param``.

    Args:
        input_param: Dict read for ``"npu_json_path"`` and ``"bench_json_path"``.
        layer_mapping_path: Optional layer mapping YAML used to align scopes.
        output_path: Passed only to ``generate_data_mapping``. The intermediate API mapping
            is requested without an output path.

    Returns:
        dict: ``{npu_full_op_name: bench_full_op_name}``.
    """
    npu_json_path = input_param.get("npu_json_path")
    bench_json_path = input_param.get("bench_json_path")
    api_mapping = generate_api_mapping_by_layer_mapping(
        npu_json_path, bench_json_path, layer_mapping_path=layer_mapping_path)
    return generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path=output_path)
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import re
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def postprocess_pass(data_items, name2item):
    """
    Run the post-processing passes over parsed data items, in order.

    1. ``backward_pass``: backward items inherit the scope information of their forward items.
    2. ``renumber_index_pass``: children of ``layers`` under ``ParallelTransformer`` scopes are renumbered.
    """
    backward_pass(data_items, name2item)
    renumber_index_pass(data_items, type_name="ParallelTransformer", suffix="layers")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def backward_pass(data_items, name2item):
    """
    Copy scope information from forward data items onto their backward counterparts.

    Backward data has no stack information of its own, so it reuses that of the forward data.
    Some items have ``Const.BACKWARD`` as a segment at or after ``Const.SCOPE_DIRECTION_INDEX``
    in their ``data_name``. For each such item the forward name is built by replacing
    ``Const.BACKWARD`` with ``Const.FORWARD`` in those trailing segments. If ``name2item``
    holds a truthy item for that name, its ``stack_scope``, ``full_scope`` and ``layer_scope``
    are copied over. Items are modified in place.
    """
    direction_idx = Const.SCOPE_DIRECTION_INDEX
    for item in data_items:
        segments = item.data_name.split(Const.SEP)
        tail = segments[direction_idx:]
        if Const.BACKWARD not in tail:
            continue
        forward_tail = [segment.replace(Const.BACKWARD, Const.FORWARD) for segment in tail]
        forward_name = Const.SEP.join(segments[:direction_idx] + forward_tail)
        forward_item = name2item.get(forward_name, None)
        if not forward_item:
            continue
        item.stack_scope = forward_item.stack_scope
        item.full_scope = forward_item.full_scope
        item.layer_scope = forward_item.layer_scope
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def extract_next_item_last_number(data, prefix, default_result=None):
    """
    Return the last integer in the scope segment that directly follows ``prefix`` in ``data``.

    ``data`` must start with ``prefix`` followed by ``'.'``; the next segment runs up to the
    following ``'.'`` or the end of the string. For example, with
    ``data='Net.layers.block12_3.attn'`` and ``prefix='Net.layers'`` the segment is
    ``'block12_3'`` and the result is ``3``.

    Returns:
        The last run of digits in that segment as an ``int``, or ``default_result`` when
        ``data`` does not start with ``prefix.`` or the segment has no digits.
    """
    found = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", data)
    if not found:
        return default_result
    digit_runs = re.findall(r"\d+", found.group(1))
    return int(digit_runs[-1]) if digit_runs else default_result
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def replace_next_item_index(full_scope, prefix, index):
    """
    Replace the scope segment that directly follows ``prefix`` in ``full_scope`` with ``index``.

    For example, with ``full_scope='Net.layers.12.attn'``, ``prefix='Net.layers'`` and
    ``index=0`` the result is ``'Net.layers.0.attn'``.

    ``full_scope`` is returned unchanged in two cases:
    - It does not start with ``prefix.``.
    - ``index`` is not finite. ``inf`` is used as the "no number found" sentinel, and ``nan``
      results from ``inf - inf``; writing either would corrupt the scope.

    Args:
        full_scope: Dotted scope string to rewrite.
        prefix: Dotted scope prefix whose following segment is replaced.
        index: New value for that segment (normally an int).

    Returns:
        str: The rewritten scope.
    """
    if not math.isfinite(index):
        return full_scope
    match = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", full_scope)
    if not match:
        return full_scope
    # Splice by position rather than re.sub: a re.sub replacement string is a template, so
    # backslashes in ``prefix`` would be interpreted instead of copied literally.
    return f"{prefix}.{index}{full_scope[match.end(1):]}"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def renumber_index_pass(data_items, type_name, suffix=None):
    """
    Renumber child indices under layers of ``type_name`` so that each group starts from 0.

    This resolves comparison problems caused by inconsistent numbering in parallel-split
    scenarios. For example, when a ParallelTransformer layer is split by pipeline parallelism
    (PP), MindSpore numbers the members of ``layers`` globally while PyTorch numbers them
    locally. To align them, the indices under each matching prefix are shifted so that the
    smallest one becomes 0.

    Args:
        data_items: Data items whose ``full_scope`` is rewritten in place.
        type_name: Only items with this ``type_name`` contribute prefixes to renumber.
        suffix: Child name appended to such an item's ``full_scope`` to form the prefix
            (e.g. ``"layers"``). When falsy, the item's ``layer_scope`` is used instead.
    """
    # Prefixes contributed by type_name items, deduplicated in first-seen order.
    prefixes = dict.fromkeys(
        f"{item.full_scope}.{suffix}" if suffix else item.layer_scope
        for item in data_items
        if item.type_name == type_name
    )

    # Smallest index found directly under each prefix (math.inf if none), computed before
    # any renumbering takes place.
    min_indices = {
        prefix: min(
            (extract_next_item_last_number(item.full_scope, prefix, math.inf) for item in data_items),
            default=math.inf,
        )
        for prefix in prefixes
    }

    # Renumber: shift each index under the prefix by that prefix's minimum.
    for prefix, min_index in min_indices.items():
        if math.isinf(min_index):
            # No numbered child under this prefix, so there is nothing to renumber. Skipping
            # also avoids inf - inf == nan, which would otherwise be written into the scope.
            continue
        for item in data_items:
            abs_index = extract_next_item_last_number(item.full_scope, prefix, math.inf)
            if math.isinf(abs_index):
                continue
            item.full_scope = replace_next_item_index(item.full_scope, prefix, abs_index - min_index)
|