mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import load_json
|
|
20
|
+
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
+
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
22
|
+
from msprobe.visualization.graph.graph import Graph
|
|
23
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
24
|
+
from msprobe.visualization.utils import save_json_file, GraphConst
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GraphBuilder:
    """Builds a ``Graph`` from msprobe dump files and exports graphs to a .vis (JSON) file."""

    # Matches node names ending in ".backward.<N>", e.g. "Module.x.Linear.backward.0"
    backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
    # Matches an identifier that starts with an uppercase letter, continues with letters only,
    # and ends with "Template(", e.g. "TensorOPTemplate("
    template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')

    @staticmethod
    def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
        """
        Public entry point that builds a graph from dump artifacts.

        Args:
            construct_path: path to construct.json (maps a child node name to its parent node name).
            data_path: path to dump.json.
            stack_path: path to stack.json.
            model_name: name given to the graph; supplied by the caller.
            complete_stack: keep the full stack traces when True; otherwise they are reduced
                by _simplify_stack.

        Returns:
            Graph: the constructed graph.
        """
        construct_dict = load_json(construct_path)
        dump_dict = load_json(data_path)
        stack_dict = load_json(stack_path)
        if not complete_stack:
            GraphBuilder._simplify_stack(stack_dict)
        data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
        graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
        GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
        GraphBuilder._collect_apis_between_modules(graph)
        return graph

    @staticmethod
    def to_json(filename, config):
        """
        Export the graph(s) described by ``config`` to a .vis file.

        With a benchmark graph (config.graph_b), the NPU and benchmark graphs are stored under
        separate keys. Otherwise the NPU graph's dict is the top level. Tool tips, node colors,
        micro steps and task are added only when truthy. The overflow-check flag is always written.

        Args:
            filename: output path, written via save_json_file.
            config: object exposing graph_n, graph_b, tool_tip, node_colors, micro_steps,
                task and overflow_check.
        """
        result = {}
        if config.graph_b:
            result[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
            result[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
        else:
            result = config.graph_n.to_dict()
        if config.tool_tip:
            result[GraphConst.JSON_TIP_KEY] = config.tool_tip
        if config.node_colors:
            result[GraphConst.COLORS] = config.node_colors
        if config.micro_steps:
            result[GraphConst.MICRO_STEPS] = config.micro_steps
        if config.task:
            result[GraphConst.JSON_TASK_KEY] = config.task
        result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
        save_json_file(filename, result)

    @staticmethod
    def _simplify_stack(stack_dict):
        """
        Reduce each stack list in ``stack_dict`` (in place) to its most relevant frame.

        Module entries keep the first frame containing "<module name>(". API entries keep the
        frame right after the first frame that contains "xxxTemplate(". When nothing matches,
        the list is left unchanged. Non-list values are skipped.

        Module example: for Module.layer3.0.bn2.BatchNorm2d.forward.0 the module name is bn2,
        so the frame matching "bn2(" is kept:
        "File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"

        API example: for Tensor.__iadd__.4.forward with the stack
        "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
        "File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
        the first frame matches "TensorOPTemplate(", so the second frame is kept.
        """
        module_pattern = re.compile(op_patterns[0])
        for dump_name, stack_list in stack_dict.items():
            if not isinstance(stack_list, list):
                continue
            if module_pattern.match(dump_name):
                parts = dump_name.split(Const.SEP)
                if len(parts) < abs(Const.LAYER_NAME_INDEX):
                    continue
                module_name = parts[Const.LAYER_NAME_INDEX]
                # Escape the name so any regex metacharacters in it are matched literally
                call_pattern = re.compile(re.escape(module_name) + r'\(')
                for stack in stack_list:
                    if call_pattern.search(stack):
                        stack_list = [stack]
                        break
            else:
                for index, stack in enumerate(stack_list):
                    if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
                        stack_list = [stack_list[index + 1]]
                        break
            # Only existing keys are reassigned, which is safe while iterating items()
            stack_dict[dump_name] = stack_list

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """
        Recover a missing parent for a backward node from its forward counterpart.

        If ``subnode_id`` ends with ".backward.<N>" and ``upnode_id`` is falsy, look up the
        parent of the same-numbered forward node in ``construct_dict``. Convert that parent's
        ".forward.<N>" suffix to ".backward.<N>", and return it if it exists in
        ``construct_dict``. In every other case ``upnode_id`` is returned unchanged.
        """
        backward_pattern = GraphBuilder.backward_pattern
        # Matches names ending in ".forward." followed by one or more digits
        forward_pattern = r"(\.forward\.)(\d+)$"
        if backward_pattern.search(subnode_id) and not upnode_id:
            forward_upnode_id = construct_dict.get(backward_pattern.sub(r".forward.\2", subnode_id))
            if forward_upnode_id:
                new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
                if new_upnode_id in construct_dict:
                    return new_upnode_id
        return upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict, stack_dict):
        """
        Create every node listed in ``construct_dict`` and attach it to its parent.

        Nodes whose parent cannot be resolved are attached to ``graph.root``.
        """
        for subnode_id, upnode_id in construct_dict.items():
            upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id)
            if upnode_id:
                upnode_op = NodeOp.get_node_op(upnode_id)
                upnode = GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], upnode_op, upnode_id)
            else:
                upnode = graph.root
            node_op = NodeOp.get_node_op(subnode_id)
            GraphBuilder._create_or_get_node(graph, [data_dict, stack_dict], node_op, subnode_id, upnode)

    @staticmethod
    def _create_or_get_node(graph, data_stack_list, op, name, upnode=None):
        """
        Return the node called ``name``, creating it under ``upnode`` if it does not exist yet.

        The node's input/output data and stack info are (re)populated on every call, and
        ``upnode`` is attached via ``add_upnode``.

        Args:
            graph: the graph being built.
            data_stack_list: [data_dict, stack_dict], both keyed by node name.
            op: NodeOp for the node (only used when the node is created).
            name: node name.
            upnode: parent node, or None.

        Returns:
            The node.
        """
        if name in graph.node_map:
            node = graph.get_node(name)
        else:
            graph.add_node(op, name, upnode)
            node = graph.get_node(name)
        node_data = data_stack_list[0].get(name, {})
        node_stack_info = data_stack_list[1].get(name, [])
        # Attach input/output data
        input_data, output_data = get_input_output(node_data, node.id)
        node.set_input_output(input_data, output_data)
        # Backward nodes reuse the stack info of the matching forward node.
        # Naming examples: module Module.module.module.GPTModel.backward.0; API Tensor.permute.1.backward
        is_module_backward = GraphBuilder.backward_pattern.search(name)
        if not node_stack_info and (is_module_backward or name.endswith(f'{Const.SEP}{Const.BACKWARD}')):
            if is_module_backward:
                # A module name is globally unique and its stack is the same on every call,
                # so always use call #0 of the forward module to avoid missing entries
                forward_name = GraphBuilder.backward_pattern.sub(f'{Const.SEP}{Const.FORWARD}{Const.SEP}0', name)
            else:
                # Replace only the trailing "backward" token. str.replace would also rewrite
                # any "backward" substring inside the API name itself.
                forward_name = name[:-len(Const.BACKWARD)] + Const.FORWARD
            forward_node = graph.get_node(forward_name)
            node_stack_info = forward_node.stack_info if forward_node \
                else ['This backward node cannot find the forward node and cannot retrieve stack information.']
        node.stack_info = node_stack_info
        # Attach the parent
        node.add_upnode(upnode)
        return node

    @staticmethod
    def _collect_apis_between_modules(graph):
        """
        Group runs of API nodes directly under the root into api_collection nodes.

        When the graph is first expanded, the top level mixes modules with many APIs. Long runs
        of APIs stretch the graph and make it hard to read. Each run of two or more consecutive
        function_api nodes is therefore moved under a new api_collection node.

        Args:
            graph: the model graph.

        Returns:
            None
        """
        i = 0
        output = []
        node_list = graph.root.subnodes
        while i < len(node_list):
            current_node = node_list[i]

            # The current node is an API: gather every API node that immediately follows it
            if current_node.op == NodeOp.function_api:
                temp_nodes = [current_node]
                i += 1
                while i < len(node_list) and node_list[i].op == NodeOp.function_api:
                    temp_nodes.append(node_list[i])
                    i += 1

                # Only runs of at least two API nodes are collected
                if len(temp_nodes) >= 2:
                    # Create a new node and move the API nodes into its subnodes
                    node_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                             id_accumulation=True)
                    api_collection_node = graph.get_node(node_id)
                    api_collection_node.subnodes = temp_nodes
                    # Re-establish the parent/child links
                    for node in temp_nodes:
                        node.upnode = api_collection_node
                    api_collection_node.upnode = graph.root
                    output.append(api_collection_node)
                else:
                    # A lone API node is kept as-is
                    output.extend(temp_nodes)
            else:
                # Module nodes are kept as-is
                output.append(current_node)
                i += 1

        graph.root.subnodes = output
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class GraphExportConfig:
    """Everything GraphBuilder.to_json reads when exporting one graph or an NPU/benchmark pair."""

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
                 overflow_check=False):
        # NPU graph (always exported) and optional benchmark graph
        self.graph_n, self.graph_b = graph_n, graph_b
        # Optional extra sections; to_json writes tool_tip, node_colors, micro_steps and task only when truthy
        self.tool_tip, self.node_colors, self.micro_steps = tool_tip, node_colors, micro_steps
        self.task = task
        # Always written by to_json
        self.overflow_check = overflow_check
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import re
|
|
16
|
+
import math
|
|
17
|
+
from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
|
|
18
|
+
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
|
+
from msprobe.visualization.utils import GraphConst
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
22
|
+
|
|
23
|
+
# Regex rules for mapping a node name to its NodeOp, by prefix.
# Dots are escaped so that only the literal "<Prefix>." is accepted ("ModuleX..." does not match).
op_patterns = [
    # NodeOp.module
    r'^(Module\.|Cell\.)',
    # NodeOp.function_api
    r'^(Tensor\.|Torch\.|Functional\.|NPU\.|VF\.|Distributed\.|Aten\.|Mint\.|Primitive\.|Jit\.|MintFunctional\.)'
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_compare_mode(dump_path_param):
    """
    Determine the graph comparison mode for a dump.

    set_dump_path is applied to ``dump_path_param`` first. The dump mode reported by
    get_dump_mode is then translated through
    GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.

    Args:
        dump_path_param: the parameter object the acc_compare interfaces depend on.

    Returns:
        The result of ``.get(dump_mode)`` on that mapping. The original contract documents
        0 as summary mode, 1 as md5 mode and 2 as real-data mode.
    """
    set_dump_path(dump_path_param)
    return GraphConst.DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING.get(get_dump_mode(dump_path_param))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
    """
    Run the multi-process real-data comparison and write its results.

    Args:
        dump_path_param: the parameter object the acc_compare interfaces depend on.
        csv_path: path of the output file.
        framework: framework type. Const.PT_FRAMEWORK selects the PyTorch comparator;
            any other value selects the MindSpore comparator.
        is_cross_frame: whether to compare across frameworks. Only MindSpore against PyTorch
            is supported, with PyTorch as the benchmark. Applied to the MindSpore comparator only.

    Returns:
        Whatever the comparator's do_multi_process returns.
    """
    mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)

    # Comparators are imported lazily so only the selected framework's stack is loaded
    if framework == Const.PT_FRAMEWORK:
        from msprobe.pytorch.compare.pt_compare import PTComparator
        comparator = PTComparator(mode_config)
    else:
        from msprobe.mindspore.compare.ms_compare import MSComparator
        comparator = MSComparator(mode_config)
        comparator.cross_frame = is_cross_frame
    return comparator.do_multi_process(dump_path_param, csv_path)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_input_output(node_data, node_id):
    """
    Split a node's raw dump data into input and output parameter dicts.

    Args:
        node_data: the dump data belonging to a single node.
        node_id: the node name. It is passed to read_op and, when it prefixes an entry's
            full_op_name, is stripped before the input/output classification.

    Returns:
        (input_data, output_data): two dicts whose values are the items produced by read_op.
        Outputs are keyed by full_op_name. Inputs are keyed by the on-disk data name (with the
        part after its last separator removed) when available, otherwise by full_op_name.
    """
    input_data = {}
    output_data = {}
    op_parsed_list = read_op(node_data, node_id)
    for item in op_parsed_list:
        full_op_name = item.get('full_op_name', '')
        if not full_op_name:
            continue
        # Classify using only the part after the node name. A node name may itself contain the
        # input/output marker, which would otherwise push the node's outputs into input_data.
        # If the name is not prefixed by node_id, fall back to the whole name.
        if isinstance(node_id, str) and node_id and full_op_name.startswith(node_id):
            io_part = full_op_name[len(node_id):]
        else:
            io_part = full_op_name
        if GraphConst.OUTPUT in io_part and GraphConst.INPUT not in io_part:
            output_data[full_op_name] = item
        else:
            name = item.get('data_name')
            # Prefer the name of the data saved to disk for the parameter key; '-1' means none
            if isinstance(name, str) and name != '-1':
                input_data[name.rsplit(Const.SEP, 1)[0]] = item
            else:
                input_data[full_op_name] = item
    return input_data, output_data
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def compare_data(data_dict_list1, data_dict_list2):
    """Check whether two ``get_input_output`` results are structurally identical.

    Both arguments map parameter name -> parameter info dict. Entries are
    paired positionally in iteration order (the names themselves are not
    compared), and each pair must agree on ``type`` and ``shape``.

    Returns:
        True when both mappings have the same length and every pair agrees,
        otherwise False.
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    # Fields that must be equal for two parameters to count as the same.
    tag_keys = ('type', 'shape')
    for name_a, name_b in zip(data_dict_list1, data_dict_list2):
        info_a = data_dict_list1[name_a]
        info_b = data_dict_list2[name_b]
        if any(info_a.get(tag) != info_b.get(tag) for tag in tag_keys):
            return False
    return True
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def compare_data_fuzzy(data_dict_list1, data_dict_list2):
    """Fuzzy structural match: only parameter shapes are checked.

    Values of the two mappings are paired positionally. Because ``zip`` stops at
    the shorter mapping, extra trailing entries on either side are not checked.

    Returns:
        False as soon as one pair has differing ``Const.SHAPE`` values, else True.
    """
    paired = zip(data_dict_list1.values(), data_dict_list2.values())
    return not any(left.get(Const.SHAPE) != right.get(Const.SHAPE) for left, right in paired)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def format_node_data(data_dict):
    """Format every parameter entry of a node's data for output, in place.

    For each dict-valued entry the bookkeeping keys ``requires_grad`` and
    ``full_op_name`` are removed and the entry is normalized by ``_format_data``.
    Non-dict entries are left untouched.

    Returns:
        The same ``data_dict`` object, mutated in place.
    """
    keys_to_drop = ('requires_grad', 'full_op_name')
    for entry in data_dict.values():
        if isinstance(entry, dict):
            for drop_key in keys_to_drop:
                entry.pop(drop_key, None)
            _format_data(entry)
    return data_dict
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
    """Compute accuracy-comparison rows for an (NPU, bench) node pair.

    Both nodes are converted with ``_parse_node`` and passed to ``get_accuracy``
    from acc_compare.py. In real-data mode no precision metrics are produced
    here; they have to come from the multi-process comparison interface.

    Args:
        node_ids: ``[npu_node_id, bench_node_id]``.
        data_dicts: ``[npu_dump_data, bench_dump_data]``, indexed like ``node_ids``.
        stack_json_data: Stack information looked up by node id.
        compare_mode: Graph compare mode, translated to a dump mode through
            ``GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING``.

    Returns:
        List with parameter info and comparison metrics (metrics excluded in
        real-data mode).
    """
    merged_npu = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode)
    merged_bench = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode)
    rows = []
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    # get_accuracy's return value is unused; it is expected to fill ``rows`` in place.
    get_accuracy(rows, merged_npu, merged_bench, dump_mode)
    return rows
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
    """Convert a node into the merged form accepted by acc_compare.py's ``get_accuracy``.

    The node's dump entries are parsed with ``read_op``. A stack-info record is
    appended (its ``full_info`` is None when the node has no stack entry), and
    the list is merged with ``merge_tensor``.

    Returns:
        The merged result; when it is empty, ``op_name`` is set to an empty list.
    """
    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
    parsed_ops = read_op(data_dict.get(node_id, {}), node_id)
    stack_info = stack_json_data[node_id] if node_id in stack_json_data else None
    parsed_ops.append({'full_op_name': node_id, 'full_info': stack_info})
    merged = merge_tensor(parsed_ops, dump_mode)
    if not merged:
        # NOTE(review): assumes merge_tensor returns a (possibly empty) dict;
        # a None return would raise here - confirm.
        merged['op_name'] = []
    return merged
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _format_decimal_string(s):
|
|
170
|
+
"""
|
|
171
|
+
使用正则表达式匹配包含数字、小数点和可选的百分号的字符串
|
|
172
|
+
"""
|
|
173
|
+
pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')
|
|
174
|
+
matches = pattern.findall(s)
|
|
175
|
+
for match in matches:
|
|
176
|
+
is_percent = match.endswith('%')
|
|
177
|
+
number_str = match.rstrip('%')
|
|
178
|
+
decimal_part = number_str.split('.')[1]
|
|
179
|
+
# 如果小数位数大于6,进行处理
|
|
180
|
+
if len(decimal_part) > GraphConst.ROUND_TH:
|
|
181
|
+
number_float = float(number_str)
|
|
182
|
+
formatted_number = f"{number_float:.{GraphConst.ROUND_TH}f}"
|
|
183
|
+
# 如果原来是百分数,加回百分号
|
|
184
|
+
if is_percent:
|
|
185
|
+
formatted_number += '%'
|
|
186
|
+
# 替换原字符串中的数值部分
|
|
187
|
+
s = s.replace(match, formatted_number)
|
|
188
|
+
return s
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _format_data(data_dict):
    """Normalize one parameter's data dict in place for front-end display.

    - ``torch.ProcessGroup`` entries keep only ``type``, ``group_ranks``,
      ``group_id`` and ``data_name``.
    - ``None`` and the single-space placeholder ``' '`` become ``GraphConst.NULL``.
    - Other strings lose single quotes, have ``GraphConst.NONE`` replaced by
      ``GraphConst.NULL`` and get over-long decimals rounded.
    - Floats whose ``str`` is scientific notation are reformatted with six
      mantissa digits (``1.123123123123e-11`` -> ``1.123123e-11``); other
      floats are rounded to ``GraphConst.ROUND_TH`` digits.
    - Every value except the one under ``GraphConst.ERROR_KEY`` is converted to str.
    - If the ``Const.MAX`` value ends up null, the parameter is treated as null
      and the dict is reduced to ``{GraphConst.VALUE: GraphConst.NULL}``.
    """
    sci_notation = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$')
    all_null = False

    keys_to_keep = ['type', 'group_ranks', 'group_id', 'data_name']
    if data_dict.get('type') == 'torch.ProcessGroup':
        keys_to_remove = [key for key in data_dict if key not in keys_to_keep]
        for key in keys_to_remove:
            del data_dict[key]

    for key, value in data_dict.items():
        # Must be tested before the str branch: ' ' is itself a str, so checking
        # it after isinstance(value, str) could never match.
        if value is None or value == ' ':
            value = GraphConst.NULL
        elif isinstance(value, str):
            # Drop single quotes and turn None into null to avoid front-end parse errors.
            value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            value = _format_decimal_string(value)
        elif isinstance(value, float) and len(str(value)) < GraphConst.STR_MAX_LEN \
                and sci_notation.match(str(value)):
            value = "{:.6e}".format(value)
        elif isinstance(value, float):
            value = round(value, GraphConst.ROUND_TH)
        # Inf reaches here as a float; str() also acts as a fallback for unexpected types.
        if key != GraphConst.ERROR_KEY:
            # Everything except the error key is stringified to avoid front-end parse errors.
            value = str(value)
        # A null max means the whole parameter value is null.
        if key == Const.MAX and value == GraphConst.NULL:
            all_null = True
        data_dict[key] = value
    # Keep a single null marker when the parameter is entirely null.
    if all_null:
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
18
|
+
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
|
|
19
|
+
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
20
|
+
from msprobe.visualization.graph.node_colors import NodeColors
|
|
21
|
+
from msprobe.visualization.compare.mode_adapter import ModeAdapter
|
|
22
|
+
from msprobe.core.common.const import Const
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class GraphComparator:
    """Compare an NPU graph with a bench graph and record the results on matched nodes.

    ``graphs[0]`` is the NPU graph (``graph_n``) and ``graphs[1]`` the bench graph
    (``graph_b``). ``compare`` uses fuzzy matching when ``args.fuzzy_match`` is set.
    Otherwise it matches nodes through ``mapping_dict`` when one is given, or by
    name via ``Graph.match``.
    """

    def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
        self.graph_n = graphs[0]
        self.graph_b = graphs[1]
        self._parse_param(dump_path_param, args.output_path)
        self.framework = args.framework
        self.mapping_dict = mapping_dict
        self.fuzzy_match = args.fuzzy_match
        # Matches a ".<digits>." segment of an API node id (its dump call count).
        self.pattern = re.compile(r'\.\d+\.')

    def compare(self):
        """Run the comparison; call once after construction.

        Comparison results are written into ``graph_n``.
        """
        if self.fuzzy_match:
            self._compare_nodes_fuzzy(self.graph_n.root)
        else:
            self._compare_nodes(self.graph_n.root)
        self._postcompare()

    def add_compare_result_to_node(self, node, compare_result_list):
        """Attach comparison results to a node's input/output data.

        Args:
            node: The NPU node.
            compare_result_list: Rows holding parameter info and comparison metrics
                (metrics are absent in real-data mode). The first element of each
                row is the parameter name.
        """
        # Real-data comparison: the node is staged first; its metrics are added
        # once the multi-process comparison has produced them (see _postcompare).
        if self.ma.prepare_real_data(node):
            return
        compare_in_dict = {}
        compare_out_dict = {}
        # Separate input rows from output rows, keyed by parameter name.
        for item in compare_result_list:
            if not isinstance(item, (list, tuple)) or not item:
                continue
            target = compare_out_dict if '.output.' in item[0] else compare_in_dict
            target[item[0]] = item
        precision_index, other_dict = self.ma.parse_result(node, [compare_in_dict, compare_out_dict])
        node.data[GraphConst.JSON_INDEX_KEY] = precision_index
        node.data.update(other_dict)

    def _parse_param(self, dump_path_param, output_path):
        """Store the paths, build the mode adapter and load the NPU, bench and stack JSON data."""
        self.dump_path_param = dump_path_param
        self.output_path = output_path
        compare_mode = get_compare_mode(self.dump_path_param)
        self.ma = ModeAdapter(compare_mode)
        self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
        self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
        self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))

    def _postcompare(self):
        """Finish the comparison after all nodes have been visited.

        Aggregates the API-collection indices. In real-data mode it also runs the
        multi-process comparison on the collected CSV data and writes the resulting
        precision index into each staged node.
        """
        self._handle_api_collection_index()
        if self.ma.compare_mode != GraphConst.REAL_DATA_COMPARE:
            return
        csv_df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
        result_df = run_real_data(self.dump_path_param, csv_df, self.framework, bool(self.mapping_dict))
        # Key each row by its first column. Use .iloc: Series[0] on a label-indexed
        # row relies on pandas' deprecated positional fallback.
        compare_data_dict = {row.iloc[0]: row.tolist() for _, row in result_df.iterrows()}
        for node in self.ma.compare_nodes:
            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _handle_api_collection_index(self):
        """Give each API-collection node the worst precision index of its APIs.

        In md5 mode an index of 0 is worst, so the minimum is taken. In
        statistics/tensor mode an index of 1 is worst, so the maximum is taken.
        The aggregation starts from the opposite bound (also used for APIs
        without an index), so it always yields the worst API.
        """
        if self.ma.compare_mode == GraphConst.MD5_COMPARE:
            aggregate, start_value = min, GraphConst.MAX_INDEX_KEY
        else:
            aggregate, start_value = max, GraphConst.MIN_INDEX_KEY
        for node in self.graph_n.root.subnodes:
            if node.op != NodeOp.api_collection:
                continue
            precision_index = start_value
            for api in node.subnodes:
                precision_index = aggregate(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, start_value))
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _compare_nodes(self, node_n):
        """Recursively match NPU nodes against the bench graph and compare matched pairs.

        For each NPU node a bench counterpart is looked up. With ``mapping_dict`` this
        uses ``Graph.mapping_match``; otherwise ``Graph.match`` looks for a node with
        the same name and checks ancestors and parameters. Matched pairs are linked
        and their precision data compared. Traversal is pre-order, so a node's
        ancestors are already matched when it is visited, which later matching can
        use.
        """
        if self.mapping_dict:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
            if node_b:
                self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
                # In real-data mode this only yields basic info; the precision metrics
                # come from the multi-process comparison in _postcompare.
                self._get_and_add_result(node_n, node_b)
        for subnode in node_n.subnodes:
            self._compare_nodes(subnode)

    def _compare_nodes_fuzzy(self, node_n):
        """Recursively fuzzy-match NPU nodes against the bench graph.

        Non-API nodes (modules) are fuzzy-matched with the bench node of the same id.
        For a matched module pair, the function APIs directly under each module are
        paired by name (ignoring the dump call count) plus their call order within
        the module. Those pairs are then fuzzy-matched and compared.
        """
        if node_n.op != NodeOp.function_api:
            node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
            if node_b:
                self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
                recount_result_n = self._recount_api_node(node_n)
                recount_result_b = self._recount_api_node(node_b)
                for recount_node_id, node_id_n in recount_result_n.items():
                    api_node_n = self.graph_n.node_map.get(node_id_n)
                    if not api_node_n:
                        continue
                    candidate_b = self.graph_b.node_map.get(recount_result_b.get(recount_node_id))
                    api_node_b, api_ancestors_n, api_ancestors_b = Graph.fuzzy_match(api_node_n, candidate_b)
                    if api_node_b:
                        self._process_matched_nodes(api_node_n, api_node_b, api_ancestors_n, api_ancestors_b)
        for sub_node in node_n.subnodes:
            self._compare_nodes_fuzzy(sub_node)

    def _get_and_add_result(self, node_n, node_b):
        """Compare a matched node pair, record the CSV rows and attach results to ``node_n``."""
        compare_result_list = compare_node([node_n.id, node_b.id],
                                           [self.data_n_dict, self.data_b_dict],
                                           self.stack_json_data, self.ma.compare_mode)
        if compare_result_list:
            self.ma.add_csv_data(compare_result_list)
            self.add_compare_result_to_node(node_n, compare_result_list)

    def _recount_api_node(self, node):
        """Re-number a module's direct function-API children by call order within the module.

        The dump call count is stripped from each API id, and the API's
        occurrence number within ``node`` is appended instead.

        Returns:
            ``{recounted_node_id: original_node_id}``
        """
        recount_result = {}
        node_count = {}
        for sub_node in node.subnodes:
            if sub_node.op != NodeOp.function_api:
                continue
            # Ignore the dump call count.
            count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
            node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
            # Append the call order within this module.
            recount_node_id = count_removed_id + str(node_count[count_removed_id])
            recount_result[recount_node_id] = sub_node.id
        return recount_result

    def _process_matched_nodes(self, node_n, node_b, ancestors_n, ancestors_b):
        """Cross-link a matched node pair via their ancestor chains, then compare them."""
        ancestors_n.append(node_n.id)
        ancestors_b.append(node_b.id)
        node_n.matched_node_link = ancestors_b
        node_b.matched_node_link = ancestors_n
        self._get_and_add_result(node_n, node_b)
|