mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
- mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +51 -20
- msprobe/config.json +2 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +264 -15
- msprobe/core/common/exceptions.py +27 -3
- msprobe/core/common/file_utils.py +176 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +204 -77
- msprobe/core/common_config.py +49 -14
- msprobe/core/compare/acc_compare.py +274 -198
- msprobe/core/compare/check.py +32 -33
- msprobe/core/compare/compare_cli.py +32 -14
- msprobe/core/compare/highlight.py +283 -127
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +135 -144
- msprobe/core/compare/utils.py +419 -274
- msprobe/core/data_dump/data_collector.py +60 -28
- msprobe/core/data_dump/data_processor/base.py +84 -36
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
- msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +119 -39
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +189 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +96 -7
- msprobe/docs/02.config_introduction.md +50 -23
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +93 -61
- msprobe/docs/06.data_dump_MindSpore.md +200 -95
- msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
- msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +561 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +466 -0
- msprobe/docs/22.visualization_MindSpore.md +481 -0
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/25.tool_function_introduction.md +29 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +29 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +25 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
- msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
- msprobe/mindspore/api_accuracy_checker/main.py +28 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +35 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +88 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -24
- msprobe/mindspore/compare/ms_compare.py +333 -268
- msprobe/mindspore/compare/ms_graph_compare.py +95 -52
- msprobe/mindspore/debugger/debugger_config.py +7 -1
- msprobe/mindspore/debugger/precision_debugger.py +87 -12
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +35 -12
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +27 -16
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +285 -113
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +48 -10
- msprobe/pytorch/__init__.py +8 -6
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +7 -6
- msprobe/pytorch/common/utils.py +101 -7
- msprobe/pytorch/compare/distributed_compare.py +17 -30
- msprobe/pytorch/compare/pt_compare.py +44 -22
- msprobe/pytorch/debugger/debugger_config.py +46 -27
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +10 -2
- msprobe/pytorch/free_benchmark/common/utils.py +29 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -38
- msprobe/pytorch/monitor/__init__.py +0 -0
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +425 -0
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +1076 -0
- msprobe/pytorch/monitor/module_metric.py +172 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
- msprobe/pytorch/monitor/optimizer_collect.py +333 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
- msprobe/pytorch/monitor/utils.py +321 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +58 -27
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +21 -27
- msprobe/pytorch/parse_tool/lib/config.py +6 -8
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
- msprobe/pytorch/parse_tool/lib/utils.py +33 -53
- msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +188 -108
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +222 -0
- msprobe/visualization/builder/msprobe_adapter.py +227 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +180 -0
- msprobe/visualization/compare/mode_adapter.py +197 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +119 -0
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +209 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +288 -0
- msprobe/visualization/utils.py +217 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
- /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
16
|
+
from msprobe.visualization.graph.base_node import BaseNode
|
|
17
|
+
from msprobe.visualization.graph.node_op import NodeOp
|
|
18
|
+
from msprobe.visualization.utils import GraphConst
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
20
|
+
from msprobe.core.common.const import Const
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Upper bound on recursion depth.
# NOTE(review): not referenced anywhere in the visible part of this module
# (Graph.dfs recurses without checking it) — confirm where it is consumed.
MAX_RECUR_LEVEL = 100
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Graph:
    """Container for the nodes of one model graph, indexed by node id.

    Attributes:
        node_map: dict mapping node id -> node object (created via BaseNode).
        node_id_map: dict mapping a base node id -> last numeric suffix handed
            out by ``add_node`` when ``id_accumulation`` is enabled.
        root: the node registered under ``model_name`` at construction time.
        data_path: value emitted under GraphConst.JSON_DATA_KEY by ``to_dict``.
        dump_data: data handed to AnomalyDetector in ``overflow_check``.
    """

    def __init__(self, model_name, data_path='', dump_data=None):
        self.node_map = {}
        self.node_id_map = {}
        # The model itself is registered as a module node and becomes the root.
        self.add_node(NodeOp.module, model_name)
        self.root = self.get_node(model_name)
        self.data_path = data_path
        self.dump_data = dump_data

    def __str__(self):
        """Return the string form of every node, one per line, in insertion order."""
        return "\n".join(str(node) for node in self.node_map.values())

    @staticmethod
    def match(graph_n, node_n, graph_b):
        """
        Find the node in ``graph_b`` that corresponds to ``node_n``.

        The precondition is that the parent of ``node_n`` has already been
        matched. Matching is exact: the node with the same id must exist in
        ``graph_b``, compare equal to ``node_n``, and have equal ancestors.
        Fuzzy matching logic may be added here later.

        Args:
            graph_n: graph that owns ``node_n`` (not used by the current logic)
            node_n: node to look up; may be None
            graph_b: graph to search in
        Returns: ``(matched_node, ancestors)``, or ``(None, [])`` when there is no match
        """
        if not node_n or node_n.id not in graph_b.node_map:
            return None, []
        node_b = graph_b.node_map.get(node_n.id)
        if node_n != node_b:
            return None, []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        if ancestors_n != ancestors_b:
            return None, []
        return node_b, ancestors_n

    @staticmethod
    def mapping_match(node_n, graph_b, mapping_dict):
        """
        Match a node according to the mapping configuration.

        The id of ``node_n`` is translated through ``mapping_dict`` (falling
        back to the id itself when it has no entry) and looked up in ``graph_b``.

        Args:
            node_n: node to look up; may be None
            graph_b: graph to search in
            mapping_dict: dict mapping node ids of this graph to ids in ``graph_b``
        Returns: ``(matched_node, ancestors_n, ancestors_b)``, or
            ``(None, [], [])`` when there is no match
        """
        # Treat a missing node as "no match", like match() and fuzzy_match(),
        # instead of failing on the attribute access below.
        if not node_n:
            return None, [], []
        node_b = graph_b.node_map.get(mapping_dict.get(node_n.id, node_n.id))
        if not node_b:
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def fuzzy_match(node_n, node_b):
        """
        Match two nodes using ``node_n.fuzzy_eq(node_b)``.

        Returns: ``(node_b, ancestors_n, ancestors_b)``, or ``(None, [], [])``
            when either node is missing or the fuzzy comparison fails
        """
        if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def dfs(node, result):
        """
        Store ``node.to_dict()`` for ``node`` and all of its descendants into
        ``result`` (keyed by node id), visiting a node before its subnodes.
        """
        result[node.id] = node.to_dict()
        for subnode in node.subnodes:
            Graph.dfs(subnode, result)

    @staticmethod
    def split_nodes_by_micro_step(nodes):
        """
        Split the nodes of one step into micro steps, based on Module names.
        A micro step must be one complete forward + backward pass.
        Example::
            =============== micro step0
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step1
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward
            =============== micro step2
            Module.forward
            Module.forward
            ...
            Module.backward
            Module.backward

        A non-Module node is assigned to the micro step of the preceding Module node.

        Returns: dict mapping micro step index -> list of nodes, starting at index 0
        """
        # Loop-invariant marker identifying forward module ids.
        forward_marker = f'{Const.SEP}{Const.FORWARD}{Const.SEP}'
        micro_step = 0
        result = {micro_step: []}
        backward_flag = False

        for node in nodes:
            if node.op == NodeOp.module:
                if forward_marker in node.id:
                    # A forward module seen after a non-forward module opens a new micro step.
                    if backward_flag:
                        micro_step += 1
                        result[micro_step] = []
                        backward_flag = False
                else:
                    backward_flag = True
            result[micro_step].append(node)
        return result

    def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
        """
        Add a node to the graph.
        Args:
            node_op: type of the node to add
            node_id: id of the node to add
            up_node: parent node of the new node
            id_accumulation: whether to append an increasing numeric suffix to a repeated node_id
        Returns: the id under which the node is stored (the existing id when
            the node is already present and id_accumulation is off)
        """
        if node_id in self.node_map:
            if id_accumulation:
                # The bare id is already taken: restart its counter; the
                # increment below then yields the suffix ".1".
                self.node_id_map[node_id] = 0
            else:
                return node_id
        if id_accumulation:
            if node_id in self.node_id_map:
                self.node_id_map[node_id] += 1
            else:
                self.node_id_map[node_id] = 0
            node_id = f'{node_id}.{self.node_id_map[node_id]}'
        node = BaseNode(node_op, node_id, up_node)
        self.node_map[node_id] = node
        return node_id

    def get_node(self, node_id):
        """
        Return the node stored under ``node_id``, or None if it does not exist.
        """
        return self.node_map.get(node_id)

    def to_dict(self):
        """
        Build the dict used for data output: root id, data path and the
        ``to_dict()`` form of every node keyed by node id.
        """
        return {
            GraphConst.JSON_ROOT_KEY: self.root.id if self.root else 'None',
            GraphConst.JSON_DATA_KEY: self.data_path,
            GraphConst.JSON_NODE_KEY: {
                node_id: node.to_dict() for node_id, node in self.node_map.items()
            },
        }

    def paging_by_micro_step(self, graph_other=None):
        """
        Tag the first-level nodes of the graph with a micro step id so the
        front end can page them, which helps when handling large graphs.
        In the compare scenario the micro step id of the corresponding nodes
        in ``graph_other`` is updated as well.
        Args:
            self: the current graph
            graph_other: optional, the other graph
        Returns: number of batches (micro steps)
        """
        batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
        for batch_number, nodes in batches_n.items():
            for node in nodes:
                node.micro_step_id = batch_number
                # Propagate the id to the already-matched node in graph_other.
                if graph_other and node.matched_node_link:
                    node_other = graph_other.get_node(node.matched_node_link[-1])
                    if node_other:
                        node_other.micro_step_id = batch_number
        # Walk the first-level nodes of graph_other so unmatched nodes get an id too.
        if graph_other:
            for node in graph_other.root.subnodes:
                if node.micro_step_id is None:
                    # Use the last id segment when it is an integer, else fall back to 0.
                    try:
                        micro_step_id = int(node.id.split(Const.SEP)[-1])
                    except ValueError:
                        micro_step_id = 0
                    node.micro_step_id = micro_step_id
        return len(batches_n)

    def overflow_check(self):
        """
        Run AnomalyDetector over ``self.dump_data`` and record the overflow
        level on every node the detector reports as overflowed.
        """
        detector = AnomalyDetector(self.dump_data)
        detector.analyze().filter()

        for node_id, node in self.node_map.items():
            if detector.has_overflow(node_id):
                node.set_overflow_level(detector.get_overflow_level(node_id))
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from enum import Enum
|
|
17
|
+
from msprobe.visualization.utils import GraphConst, ToolTip
|
|
18
|
+
|
|
19
|
+
# Legend text for summary-compare mode: relative error of the statistics of all
# inputs/outputs of a node; a larger value means a larger deviation from the
# benchmark. Relative error = |(measured - benchmark) / benchmark|.
SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|"
# Legend text for real-data-compare mode: absolute difference between the
# minimum "one-thousandth" ratio over all inputs and over all outputs; a larger
# value means a larger deviation. The formula text comes from ToolTip.
REAL_DATA_DESCRIPTION = (
    "此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况, "
    f"值越大代表测量值与标杆值的偏差越大, 双千分之一指标计算方式:{ToolTip.ONE_THOUSANDTH_ERR_RATIO}"
)
# Legend text for md5-compare mode when some input/output md5 differs from the benchmark.
MD5_DESCRIPTION_N = "与标杆相比, 此节点任意输入输出的md5值不同"
# Legend text for md5-compare mode when every input/output md5 equals the benchmark.
MD5_DESCRIPTION_Y = "与标杆相比, 此节点所有输入输出的md5值相同"
# Legend text for nodes that were not matched during comparison.
NOT_MATCHED = "比对过程中节点未匹配上"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class NodeColors(Enum):
    """Node colours paired with the per-mode value range and description they stand for.

    Each member value is ``(hex_value, mode_info)``. ``mode_info`` is either a
    dict keyed by compare mode, or a single info dict shared by every mode.
    """

    # The smaller the numeric suffix of a member name, the lighter the colour.
    # Value ranges are left-closed, right-open; two equal bounds denote a fixed value.
    YELLOW_1 = ("#FFFCF3", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.05], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y},
    })
    YELLOW_2 = ("#FFEDBE", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.05, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    ORANGE_1 = ("#FFDC7F", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.15], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    ORANGE_2 = ("#FFC62E", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.15, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    RED = ("#FF704D", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N},
    })
    GREY = ("#C7C7C7", {
        GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED
    })

    def __init__(self, hex_value, mode_info):
        # Enum unpacks the member's tuple value into these two arguments.
        self.hex_value = hex_value
        self.mode_info = mode_info

    @staticmethod
    def get_node_colors(mode):
        """
        Collect the colour legend for one compare mode.
        Args:
            mode: compare mode
        Returns: dict mapping hex colour -> info dict, omitting colours with no info for the mode
        """
        legend = {}
        for color in NodeColors:
            info = color.get_info_by_mode(mode)
            if info:
                legend[color.hex_value] = info
        return legend

    @staticmethod
    def get_node_error_status(mode, value):
        """
        Tell whether a precision-compare metric is greater than the baseline,
        i.e. the lower bound of the ORANGE_1 range for the given mode.
        Args:
            mode: compare mode
            value: precision-compare metric
        Returns: bool
        """
        baseline_info = NodeColors.ORANGE_1.get_info_by_mode(mode)
        if not baseline_info or GraphConst.VALUE not in baseline_info:
            return False
        lower_bound = baseline_info[GraphConst.VALUE][0]
        return value > lower_bound

    def get_info_by_mode(self, mode):
        """
        Return this colour's info dict for ``mode``.

        Returns the mode-specific entry (``{}`` when the mode is absent), the
        shared info dict when the info is not keyed by mode, or ``{}`` when
        ``mode_info`` is not a dict.
        """
        if not isinstance(self.mode_info, dict):
            return {}
        # A dict whose first value is itself a dict is keyed by compare mode.
        first_entry = next(iter(self.mode_info.values()))
        if isinstance(first_entry, dict):
            return self.mode_info.get(mode, {})
        # Otherwise every mode shares the same info.
        return self.mode_info
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from enum import Enum
|
|
17
|
+
import re
|
|
18
|
+
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class NodeOp(Enum):
    """Kinds of graph node; a member's value is used as an index into ``op_patterns``."""

    module = 0
    function_api = 1
    api_collection = 9

    @staticmethod
    def get_node_op(node_name: str):
        """
        Parse the node kind from the string that names the node.

        Members are tried in definition order; the first one whose pattern
        matches the start of ``node_name`` is returned.

        Raises:
            Exception: when a member's value is not a valid index into
                ``op_patterns``, or when no pattern matches ``node_name``.
        """
        pattern_count = len(op_patterns)
        for candidate in NodeOp:
            position = candidate.value
            if not 0 <= position < pattern_count:
                raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
            if re.match(op_patterns[position], node_name):
                return candidate
        raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import time
|
|
18
|
+
import json
|
|
19
|
+
from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker,
|
|
20
|
+
check_file_or_directory_path)
|
|
21
|
+
from msprobe.core.common.const import FileCheckConst, Const
|
|
22
|
+
from msprobe.core.common.utils import CompareException
|
|
23
|
+
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
24
|
+
from msprobe.visualization.compare.graph_comparator import GraphComparator
|
|
25
|
+
from msprobe.visualization.utils import GraphConst, check_directory_content
|
|
26
|
+
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
|
|
27
|
+
from msprobe.core.common.log import logger
|
|
28
|
+
from msprobe.visualization.graph.node_colors import NodeColors
|
|
29
|
+
from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
|
|
30
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents
|
|
31
|
+
from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
|
|
32
|
+
|
|
33
|
+
# Timestamp (YYYYmmddHHMMSS) taken once, when this module is imported. It is embedded in the names of the
# exported .vis files, so every file written by one process carries the same timestamp.
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _compare_graph(input_param, args):
    """
    Build the NPU graph and the bench graph from their dump directories and compare them.

    Args:
        input_param: mapping read with ``.get``; ``'npu_path'`` and ``'bench_path'`` are the two dump
            directories, ``'is_print_compare_log'`` is optional and defaults to True.
        args: options object; ``complete_stack``, ``layer_mapping`` and ``overflow_check`` are read here
            and the whole object is also handed to ``GraphComparator``.

    Returns:
        CompareGraphResult holding both graphs, the comparator and the value returned by
        ``graph_n.paging_by_micro_step(graph_b)``.
    """

    def _checked_file(folder, file_name):
        # Run the readable-file check on folder/file_name; whatever common_check() returns is used as the path.
        return FileChecker(os.path.join(folder, file_name), FileCheckConst.FILE,
                           FileCheckConst.READ_ABLE).common_check()

    logger.info('Start building model graphs...')
    # Build a graph for each of the two dumps. Each file kind is checked for the NPU side, then the bench side.
    npu_dir = input_param.get('npu_path')
    bench_dir = input_param.get('bench_path')
    npu_construct = _checked_file(npu_dir, GraphConst.CONSTRUCT_FILE)
    bench_construct = _checked_file(bench_dir, GraphConst.CONSTRUCT_FILE)
    npu_dump = _checked_file(npu_dir, GraphConst.DUMP_FILE)
    bench_dump = _checked_file(bench_dir, GraphConst.DUMP_FILE)
    npu_stack = _checked_file(npu_dir, GraphConst.STACK_FILE)
    bench_stack = _checked_file(bench_dir, GraphConst.STACK_FILE)
    graph_n = GraphBuilder.build(npu_construct, npu_dump, npu_stack, complete_stack=args.complete_stack)
    graph_b = GraphBuilder.build(bench_construct, bench_dump, bench_stack, complete_stack=args.complete_stack)
    logger.info('Model graphs built successfully, start Comparing graphs...')

    # Compare based on graph, stack and data. Only the NPU stack file is passed to the comparator.
    compare_inputs = {
        'npu_json_path': npu_dump,
        'bench_json_path': bench_dump,
        'stack_json_path': npu_stack,
        'is_print_compare_log': input_param.get("is_print_compare_log", True)
    }

    mapping_dict = None
    if args.layer_mapping:
        yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
        try:
            mapping_dict = generate_api_mapping_by_layer_mapping(npu_dump, bench_dump, yaml_path)
        except Exception:
            # Best effort: a mapping that cannot be generated is reported and the comparison runs without it.
            logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')

    comparator = GraphComparator([graph_n, graph_b], compare_inputs, args, mapping_dict=mapping_dict)
    comparator.compare()
    micro_steps = graph_n.paging_by_micro_step(graph_b)

    # Overflow check enabled: run it on the NPU graph first, then on the bench graph.
    if args.overflow_check:
        for graph in (graph_n, graph_b):
            graph.overflow_check()

    return CompareGraphResult(graph_n, graph_b, comparator, micro_steps)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps,
                                 output_file_name=f'compare_{current_time}.vis'):
    """
    Write the result of a graph comparison into ``args.output_path``.

    Args:
        args: options object; ``output_path`` and ``overflow_check`` are read.
        graphs: indexable holding the NPU graph at position 0 and the bench graph at position 1.
        graph_comparator: comparator whose ``ma`` attribute supplies the compare mode and the tool tip.
        micro_steps: value forwarded unchanged to the export configuration.
        output_file_name: name of the result file; the default is fixed when this function is defined
            and contains the module-level ``current_time``.
    """
    create_directory(args.output_path)
    result_file = os.path.join(args.output_path, output_file_name)

    adapter = graph_comparator.ma
    task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(adapter.compare_mode)
    npu_graph = graphs[0]
    bench_graph = graphs[1]
    tool_tip = adapter.get_tool_tip()
    node_colors = NodeColors.get_node_colors(adapter.compare_mode)
    config = GraphExportConfig(npu_graph, bench_graph, tool_tip, node_colors, micro_steps, task,
                               args.overflow_check)

    GraphBuilder.to_json(result_file, config)
    logger.info(f'Model graphs compared successfully, the result file is saved in {result_file}')
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _build_graph(dump_path, args):
    """
    Build the graph of a single dump directory.

    Args:
        dump_path: directory expected to contain the construct, dump and stack files.
        args: options object; ``complete_stack`` and ``overflow_check`` are read.

    Returns:
        BuildGraphResult holding the graph and the value returned by ``graph.paging_by_micro_step()``.
    """
    logger.info('Start building model graph...')
    # Each file goes through the readable-file check, in the order construct, dump, stack.
    construct_path, data_path, stack_path = (
        FileChecker(os.path.join(dump_path, file_name), FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
        for file_name in (GraphConst.CONSTRUCT_FILE, GraphConst.DUMP_FILE, GraphConst.STACK_FILE)
    )
    built_graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
    pages = built_graph.paging_by_micro_step()
    # Overflow check enabled.
    if args.overflow_check:
        built_graph.overflow_check()
    return BuildGraphResult(built_graph, pages)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _export_build_graph_result(out_path, graph, micro_steps, overflow_check,
                               output_file_name=f'build_{current_time}.vis'):
    """
    Write a single built graph into ``out_path``.

    Args:
        out_path: output directory; ``create_directory`` is called on it first.
        graph: the graph to export.
        micro_steps: forwarded to the export configuration as ``micro_steps``.
        overflow_check: forwarded to the export configuration as ``overflow_check``.
        output_file_name: name of the result file; the default is fixed when this function is defined
            and contains the module-level ``current_time``.
    """
    create_directory(out_path)
    result_file = os.path.join(out_path, output_file_name)
    config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check)
    GraphBuilder.to_json(result_file, config)
    logger.info(f'Model graph built successfully, the result file is saved in {result_file}')
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _compare_graph_ranks(input_param, args, step=None):
    """
    Compare every rank folder of the NPU dump with the same-named rank folder of the bench dump.

    All ranks are compared first; when more than one rank was processed the NPU graphs and the bench
    graphs are each passed through ``DistributedAnalyzer(...).distributed_match()``; only then are the
    results exported, one file per rank.

    ``input_param`` is not modified: each rank is compared with its own copy in which ``'npu_path'``
    and ``'bench_path'`` point at that rank's folders.

    Args:
        input_param: mapping whose ``'npu_path'`` and ``'bench_path'`` are the directories that contain
            the rank folders; all other entries are passed on to ``_compare_graph`` unchanged.
        args: options object; ``overflow_check`` is read here and the object is passed on.
        step: optional step folder name; when truthy it is included in the output file names.

    Raises:
        CompareException: if the two directories do not contain the same rank folders, or a folder name
            other than ``Const.RANK`` cannot be turned into a rank number.
    """
    dump_rank_n = input_param.get('npu_path')
    dump_rank_b = input_param.get('bench_path')
    npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
    bench_ranks = sorted(check_and_return_dir_contents(dump_rank_b, Const.RANK))
    # The sorted folder-name lists are compared, so a different count and different names are both rejected.
    if npu_ranks != bench_ranks:
        logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    compare_graph_results = []
    # The two lists are equal from here on, so one name identifies the folder on both sides.
    for nr in npu_ranks:
        logger.info(f'Start processing data for {nr}...')
        # Work on a copy so the caller's dict keeps pointing at the directories it passed in.
        rank_param = {
            **input_param,
            'npu_path': os.path.join(dump_rank_n, nr),
            'bench_path': os.path.join(dump_rank_b, nr),
        }
        output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
        result = _compare_graph(rank_param, args)
        result.output_file_name = output_file_name
        # A folder named exactly Const.RANK carries no number; the result keeps the rank it was created with.
        if nr != Const.RANK:
            try:
                result.rank = int(nr.replace(Const.RANK, ""))
            except Exception as e:
                logger.error('The folder name format is incorrect, expected rank+number.')
                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
        # Keep every rank's graphs so distributed nodes can be matched across ranks.
        compare_graph_results.append(result)

    # Match distributed nodes across ranks.
    if len(compare_graph_results) > 1:
        DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
                            args.overflow_check).distributed_match()
        DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
                            args.overflow_check).distributed_match()

    for result in compare_graph_results:
        _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator,
                                     result.micro_steps, output_file_name=result.output_file_name)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _compare_graph_steps(input_param, args):
    """
    Compare every step folder of the NPU dump with the same-named step folder of the bench dump.

    Each step is handed to ``_compare_graph_ranks`` together with the step folder name.
    ``input_param`` is not modified: every step gets its own copy in which ``'npu_path'`` and
    ``'bench_path'`` point at that step's folders.

    Args:
        input_param: mapping whose ``'npu_path'`` and ``'bench_path'`` are the directories that contain
            the step folders; all other entries are passed on unchanged.
        args: options object, passed on unchanged.

    Raises:
        CompareException: if the two directories do not contain the same step folders.
    """
    dump_step_n = input_param.get('npu_path')
    dump_step_b = input_param.get('bench_path')

    npu_steps = sorted(check_and_return_dir_contents(dump_step_n, Const.STEP))
    bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))

    # The sorted folder-name lists are compared, so a different count and different names are both rejected.
    if npu_steps != bench_steps:
        logger.error('The number of steps in the two runs are different. Unable to match the steps.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    for folder_step in npu_steps:
        logger.info(f'Start processing data for {folder_step}...')
        # Work on a copy so the caller's dict keeps pointing at the directories it passed in.
        step_param = {
            **input_param,
            'npu_path': os.path.join(dump_step_n, folder_step),
            'bench_path': os.path.join(dump_step_b, folder_step),
        }

        _compare_graph_ranks(step_param, args, step=folder_step)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _build_graph_ranks(dump_ranks_path, args, step=None):
    """
    Build and export one graph per rank folder found in ``dump_ranks_path``.

    All ranks are built first; when more than one rank was processed the graphs are passed through
    ``DistributedAnalyzer(...).distributed_match()``; only then is each graph exported.

    Args:
        dump_ranks_path: directory that contains the rank folders.
        args: options object; ``output_path`` and ``overflow_check`` are read here and the object is
            passed on to ``_build_graph``.
        step: optional step folder name; when truthy it is included in the output file names.

    Raises:
        CompareException: if a folder name other than ``Const.RANK`` cannot be turned into a rank number.
    """
    collected = []
    for rank in sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK)):
        logger.info(f'Start processing data for {rank}...')
        rank_dump_dir = os.path.join(dump_ranks_path, rank)
        if step:
            file_name = f'build_{step}_{rank}_{current_time}.vis'
        else:
            file_name = f'build_{rank}_{current_time}.vis'
        built = _build_graph(rank_dump_dir, args)
        built.output_file_name = file_name
        # A folder named exactly Const.RANK carries no number; the result keeps the rank it was created with.
        if rank != Const.RANK:
            try:
                built.rank = int(rank.replace(Const.RANK, ""))
            except Exception as e:
                logger.error('The folder name format is incorrect, expected rank+number.')
                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
        collected.append(built)

    # Distributed matching is only attempted when there is more than one rank.
    if len(collected) > 1:
        graphs_by_rank = {item.rank: item.graph for item in collected}
        DistributedAnalyzer(graphs_by_rank, args.overflow_check).distributed_match()

    for built in collected:
        _export_build_graph_result(args.output_path, built.graph, built.micro_steps, args.overflow_check,
                                   built.output_file_name)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _build_graph_steps(dump_steps_path, args):
    """
    Build and export the graphs of every step folder found in ``dump_steps_path``.

    Step folders are processed in sorted name order; each one is handed to ``_build_graph_ranks``
    together with its folder name.
    """
    for step in sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP)):
        logger.info(f'Start processing data for {step}...')
        _build_graph_ranks(os.path.join(dump_steps_path, step), args, step)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _graph_service_parser(parser):
    """
    Register the graph-service command line options on ``parser``.

    Three string options are added (``-i``/``--input_path`` and ``-o``/``--output_path`` are required,
    ``-lm``/``--layer_mapping`` is optional), followed by three optional ``store_true`` switches
    (``-oc``/``--overflow_check``, ``-f``/``--fuzzy_match``, ``-cs``/``--complete_stack``).
    """
    # (short flag, long flag, dest, help text, required)
    string_options = (
        ("-i", "--input_path", "input_path", "<Required> The compare input path, a dict json.", True),
        ("-o", "--output_path", "output_path", "<Required> The compare task result out path.", True),
        ("-lm", "--layer_mapping", "layer_mapping", "<Optional> The layer mapping file path.", False),
    )
    # (short flag, long flag, dest, help text)
    switch_options = (
        ("-oc", "--overflow_check", "overflow_check", "<Optional> whether open overflow_check for graph."),
        ("-f", "--fuzzy_match", "fuzzy_match", "<Optional> Whether to perform a fuzzy match on the api name."),
        ("-cs", "--complete_stack", "complete_stack", "<Optional> Whether to use complete stack information."),
    )
    for short_flag, long_flag, dest, help_text, required in string_options:
        parser.add_argument(short_flag, long_flag, dest=dest, type=str, help=help_text, required=required)
    for short_flag, long_flag, dest, help_text in switch_options:
        parser.add_argument(short_flag, long_flag, dest=dest, action="store_true", help=help_text, required=False)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _graph_service_command(args):
    """
    Entry point of the graph service: build graphs for one dump, or compare two dumps.

    The JSON file at ``args.input_path`` is loaded and its ``"npu_path"`` / ``"bench_path"`` entries are
    read. With only ``npu_path`` given, graphs are built; with both given, they are compared. In both
    modes the directory layout reported by ``check_directory_content`` selects rank-level, step-level
    or single-directory processing.

    Raises:
        ValueError: if the two directories have different layouts.
        CompareException: if neither the build nor the compare condition holds.
    """
    with FileOpen(args.input_path, "r") as handle:
        params = json.load(handle)
    npu_dir = params.get("npu_path")
    bench_dir = params.get("bench_path")
    check_file_or_directory_path(npu_dir, isdir=True)
    if bench_dir:
        check_file_or_directory_path(bench_dir, isdir=True)

    # Build mode: an NPU directory and no bench path.
    if check_file_type(npu_dir) == FileCheckConst.DIR and not bench_dir:
        layout = check_directory_content(npu_dir)
        if layout == GraphConst.RANKS:
            _build_graph_ranks(npu_dir, args)
        elif layout == GraphConst.STEPS:
            _build_graph_steps(npu_dir, args)
        else:
            built = _build_graph(npu_dir, args)
            _export_build_graph_result(args.output_path, built.graph, built.micro_steps, args.overflow_check)
        return

    # Compare mode: both paths are directories.
    if check_file_type(npu_dir) == FileCheckConst.DIR and check_file_type(bench_dir) == FileCheckConst.DIR:
        npu_layout = check_directory_content(npu_dir)
        bench_layout = check_directory_content(bench_dir)
        if npu_layout != bench_layout:
            raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
        if npu_layout == GraphConst.RANKS:
            _compare_graph_ranks(params, args)
        elif npu_layout == GraphConst.STEPS:
            _compare_graph_steps(params, args)
        else:
            compared = _compare_graph(params, args)
            _export_compare_graph_result(args, [compared.graph_n, compared.graph_b],
                                         compared.graph_comparator, compared.micro_steps)
        return

    logger.error("The npu_path or bench_path should be a folder.")
    raise CompareException(CompareException.INVALID_COMPARE_MODE)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _pt_graph_service_parser(parser):
    """Register the graph-service options on ``parser`` by delegating to ``_graph_service_parser``."""
    # NOTE(review): "pt" presumably stands for the PyTorch command line entry - confirm at the call site.
    _graph_service_parser(parser)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _pt_graph_service_command(args):
    """Run the graph service for ``args`` by delegating to ``_graph_service_command``."""
    # NOTE(review): "pt" presumably stands for the PyTorch command line entry - confirm at the call site.
    _graph_service_command(args)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _ms_graph_service_parser(parser):
    """Register the graph-service options on ``parser`` by delegating to ``_graph_service_parser``."""
    # NOTE(review): "ms" presumably stands for the MindSpore command line entry - confirm at the call site.
    _graph_service_parser(parser)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _ms_graph_service_command(args):
    """Run the graph service for ``args`` by delegating to ``_graph_service_command``."""
    # NOTE(review): "ms" presumably stands for the MindSpore command line entry - confirm at the call site.
    _graph_service_command(args)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class CompareGraphResult:
    """Plain holder for the outcome of comparing an NPU graph with a bench graph."""

    def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, output_file_name=''):
        # The two graphs that were compared (NPU side, bench side).
        self.graph_n, self.graph_b = graph_n, graph_b
        # The comparator that produced the comparison.
        self.graph_comparator = graph_comparator
        self.micro_steps = micro_steps
        # Rank number and export file name; both are plain attributes that may be reassigned later.
        self.rank, self.output_file_name = rank, output_file_name
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class BuildGraphResult:
    """Plain holder for a built graph together with its micro-step paging result."""

    def __init__(self, graph, micro_steps, rank=0, output_file_name=''):
        self.graph, self.micro_steps = graph, micro_steps
        # Rank number and export file name; both are plain attributes that may be reassigned later.
        self.rank, self.output_file_name = rank, output_file_name
|