mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -146,7 +146,7 @@ class Graph:
|
|
|
146
146
|
"""
|
|
147
147
|
return self.node_map.get(node_id, None)
|
|
148
148
|
|
|
149
|
-
def to_dict(self):
|
|
149
|
+
def to_dict(self, compare_mode=None):
|
|
150
150
|
"""
|
|
151
151
|
用于数据输出
|
|
152
152
|
"""
|
|
@@ -155,7 +155,7 @@ class Graph:
|
|
|
155
155
|
result[GraphConst.JSON_DATA_KEY] = self.data_path
|
|
156
156
|
result[GraphConst.JSON_NODE_KEY] = {}
|
|
157
157
|
for node_id in self.node_map:
|
|
158
|
-
info = self.node_map.get(node_id).to_dict()
|
|
158
|
+
info = self.node_map.get(node_id).to_dict(compare_mode)
|
|
159
159
|
result[GraphConst.JSON_NODE_KEY][node_id] = info
|
|
160
160
|
return result
|
|
161
161
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,83 +15,93 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import time
|
|
18
|
-
import
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from multiprocessing import cpu_count, Pool
|
|
19
20
|
from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
|
|
20
21
|
check_file_or_directory_path, load_json)
|
|
21
22
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
22
|
-
from msprobe.core.common.utils import CompareException
|
|
23
|
-
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
23
|
+
from msprobe.core.common.utils import CompareException, get_dump_mode
|
|
24
24
|
from msprobe.visualization.compare.graph_comparator import GraphComparator
|
|
25
|
-
from msprobe.visualization.utils import GraphConst, check_directory_content
|
|
26
|
-
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
|
|
25
|
+
from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs
|
|
26
|
+
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
|
|
27
27
|
from msprobe.core.common.log import logger
|
|
28
28
|
from msprobe.visualization.graph.node_colors import NodeColors
|
|
29
29
|
from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
|
|
30
30
|
from msprobe.core.compare.utils import check_and_return_dir_contents
|
|
31
|
+
from msprobe.core.common.utils import detect_framework_by_dump_json
|
|
31
32
|
from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
|
|
32
33
|
|
|
33
34
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
def _compare_graph(input_param, args):
|
|
37
|
-
logger.info('Start building model graphs...')
|
|
38
|
-
# 对两个数据进行构图
|
|
39
|
-
dump_path_n = input_param.get('npu_path')
|
|
40
|
-
dump_path_b = input_param.get('bench_path')
|
|
41
|
-
construct_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE),
|
|
42
|
-
FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
43
|
-
construct_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE),
|
|
44
|
-
FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
45
|
-
data_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
46
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
47
|
-
data_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
48
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
49
|
-
stack_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
50
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
51
|
-
stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
52
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
53
|
-
graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
|
|
54
|
-
graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
|
|
55
|
-
logger.info('Model graphs built successfully, start Comparing graphs...')
|
|
56
|
-
# 基于graph、stack和data进行比较
|
|
37
|
+
def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
|
|
57
38
|
dump_path_param = {
|
|
58
|
-
'npu_json_path':
|
|
59
|
-
'bench_json_path':
|
|
60
|
-
'stack_json_path':
|
|
39
|
+
'npu_json_path': graph_n.data_path,
|
|
40
|
+
'bench_json_path': graph_b.data_path,
|
|
41
|
+
'stack_json_path': graph_n.stack_path,
|
|
61
42
|
'is_print_compare_log': input_param.get("is_print_compare_log", True)
|
|
62
43
|
}
|
|
63
|
-
mapping_dict =
|
|
44
|
+
mapping_dict = {}
|
|
64
45
|
if args.layer_mapping:
|
|
65
|
-
yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
66
46
|
try:
|
|
67
|
-
mapping_dict = generate_api_mapping_by_layer_mapping(
|
|
47
|
+
mapping_dict = generate_api_mapping_by_layer_mapping(graph_n.data_path, graph_b.data_path,
|
|
48
|
+
args.layer_mapping)
|
|
68
49
|
except Exception:
|
|
69
50
|
logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')
|
|
70
|
-
|
|
51
|
+
is_cross_framework = detect_framework_by_dump_json(graph_n.data_path) != \
|
|
52
|
+
detect_framework_by_dump_json(graph_b.data_path)
|
|
53
|
+
if is_cross_framework and not args.layer_mapping:
|
|
54
|
+
logger.error('The cross_frame graph comparison failed. '
|
|
55
|
+
'Please specify -lm or --layer_mapping when performing cross_frame graph comparison.')
|
|
56
|
+
raise CompareException(CompareException.CROSS_FRAME_ERROR)
|
|
57
|
+
|
|
58
|
+
graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework,
|
|
59
|
+
mapping_dict=mapping_dict)
|
|
71
60
|
graph_comparator.compare()
|
|
72
|
-
|
|
61
|
+
return graph_comparator
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _compare_graph_result(input_param, args):
|
|
65
|
+
logger.info('Start building model graphs...')
|
|
66
|
+
# 对两个数据进行构图
|
|
67
|
+
graph_n = _build_graph_info(input_param.get('npu_path'), args)
|
|
68
|
+
graph_b = _build_graph_info(input_param.get('bench_path'), args)
|
|
69
|
+
logger.info('Model graphs built successfully, start Comparing graphs...')
|
|
70
|
+
# 基于graph、stack和data进行比较
|
|
71
|
+
graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
|
|
72
|
+
# 增加micro step标记
|
|
73
|
+
micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
|
|
73
74
|
# 开启溢出检测
|
|
74
75
|
if args.overflow_check:
|
|
75
|
-
graph_n.overflow_check()
|
|
76
|
-
graph_b.overflow_check()
|
|
76
|
+
graph_n.graph.overflow_check()
|
|
77
|
+
graph_b.graph.overflow_check()
|
|
77
78
|
|
|
78
|
-
return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
|
|
79
|
+
return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
|
|
79
80
|
|
|
80
81
|
|
|
81
|
-
def _export_compare_graph_result(args,
|
|
82
|
-
|
|
83
|
-
|
|
82
|
+
def _export_compare_graph_result(args, result):
|
|
83
|
+
graphs = [result.graph_n, result.graph_b]
|
|
84
|
+
graph_comparator = result.graph_comparator
|
|
85
|
+
micro_steps = result.micro_steps
|
|
86
|
+
output_file_name = result.output_file_name
|
|
87
|
+
if not output_file_name:
|
|
88
|
+
output_file_name = f'compare_{current_time}.vis'
|
|
89
|
+
logger.info(f'Start exporting compare graph result, file name: {output_file_name}...')
|
|
84
90
|
output_path = os.path.join(args.output_path, output_file_name)
|
|
85
91
|
task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
|
|
86
92
|
export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
|
|
87
93
|
NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
|
|
88
|
-
args.overflow_check)
|
|
89
|
-
|
|
90
|
-
|
|
94
|
+
args.overflow_check, graph_comparator.ma.compare_mode)
|
|
95
|
+
try:
|
|
96
|
+
GraphBuilder.to_json(output_path, export_config)
|
|
97
|
+
logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}')
|
|
98
|
+
return ''
|
|
99
|
+
except RuntimeError as e:
|
|
100
|
+
logger.error(f'Failed to export compare graph result, file: {output_file_name}, error: {e}')
|
|
101
|
+
return output_file_name
|
|
91
102
|
|
|
92
103
|
|
|
93
|
-
def
|
|
94
|
-
logger.info('Start building model graph...')
|
|
104
|
+
def _build_graph_info(dump_path, args):
|
|
95
105
|
construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
|
|
96
106
|
FileCheckConst.READ_ABLE).common_check()
|
|
97
107
|
data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
@@ -99,6 +109,13 @@ def _build_graph(dump_path, args):
|
|
|
99
109
|
stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
100
110
|
FileCheckConst.READ_ABLE).common_check()
|
|
101
111
|
graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
|
|
112
|
+
return GraphInfo(graph, construct_path, data_path, stack_path)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _build_graph_result(dump_path, args):
|
|
116
|
+
logger.info('Start building model graphs...')
|
|
117
|
+
graph = _build_graph_info(dump_path, args).graph
|
|
118
|
+
# 增加micro step标记
|
|
102
119
|
micro_steps = graph.paging_by_micro_step()
|
|
103
120
|
# 开启溢出检测
|
|
104
121
|
if args.overflow_check:
|
|
@@ -106,15 +123,128 @@ def _build_graph(dump_path, args):
|
|
|
106
123
|
return BuildGraphResult(graph, micro_steps)
|
|
107
124
|
|
|
108
125
|
|
|
109
|
-
def
|
|
110
|
-
|
|
111
|
-
|
|
126
|
+
def _run_build_graph_compare(input_param, args, nr, br):
|
|
127
|
+
logger.info(f'Start building graph for {nr}...')
|
|
128
|
+
graph_n = _build_graph_info(input_param.get('npu_path'), args)
|
|
129
|
+
graph_b = _build_graph_info(input_param.get('bench_path'), args)
|
|
130
|
+
logger.info(f'Building graph for {nr} finished.')
|
|
131
|
+
return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _run_build_graph_single(dump_ranks_path, rank, step, args):
|
|
135
|
+
logger.info(f'Start building graph for {rank}...')
|
|
136
|
+
dump_path = os.path.join(dump_ranks_path, rank)
|
|
137
|
+
output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
|
|
138
|
+
result = _build_graph_result(dump_path, args)
|
|
139
|
+
result.output_file_name = output_file_name
|
|
140
|
+
if rank != Const.RANK:
|
|
141
|
+
try:
|
|
142
|
+
result.rank = int(rank.replace(Const.RANK, ""))
|
|
143
|
+
except Exception as e:
|
|
144
|
+
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
145
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
146
|
+
logger.info(f'Building graph for step: {step}, rank: {rank} finished.')
|
|
147
|
+
return result
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
|
|
151
|
+
logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
|
|
152
|
+
graph_n = graph_task_info.graph_info_n
|
|
153
|
+
graph_b = graph_task_info.graph_info_b
|
|
154
|
+
nr = graph_task_info.npu_rank
|
|
155
|
+
graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
|
|
156
|
+
micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
|
|
157
|
+
# 开启溢出检测
|
|
158
|
+
if args.overflow_check:
|
|
159
|
+
graph_n.graph.overflow_check()
|
|
160
|
+
graph_b.graph.overflow_check()
|
|
161
|
+
graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
|
|
162
|
+
graph_result.output_file_name = output_file_name
|
|
163
|
+
if nr != Const.RANK:
|
|
164
|
+
try:
|
|
165
|
+
graph_result.rank = int(nr.replace(Const.RANK, ""))
|
|
166
|
+
except Exception as e:
|
|
167
|
+
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
168
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
169
|
+
logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.')
|
|
170
|
+
return graph_result
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _export_build_graph_result(args, result):
|
|
174
|
+
out_path = args.output_path
|
|
175
|
+
graph = result.graph
|
|
176
|
+
micro_steps = result.micro_steps
|
|
177
|
+
overflow_check = args.overflow_check
|
|
178
|
+
output_file_name = result.output_file_name
|
|
179
|
+
if not output_file_name:
|
|
180
|
+
output_file_name = f'build_{current_time}.vis'
|
|
181
|
+
logger.info(f'Start exporting graph for {output_file_name}...')
|
|
112
182
|
output_path = os.path.join(out_path, output_file_name)
|
|
113
|
-
|
|
114
|
-
|
|
183
|
+
try:
|
|
184
|
+
GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps,
|
|
185
|
+
overflow_check=overflow_check))
|
|
186
|
+
logger.info(f'Model graph exported successfully, the result file is saved in {output_path}')
|
|
187
|
+
return None
|
|
188
|
+
except RuntimeError as e:
|
|
189
|
+
logger.error(f'Failed to export model graph, file: {output_file_name}, error: {e}')
|
|
190
|
+
return output_file_name
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
194
|
+
dump_rank_n = input_param.get('npu_path')
|
|
195
|
+
dump_rank_b = input_param.get('bench_path')
|
|
196
|
+
has_real_data = False
|
|
197
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
198
|
+
dump_path_param = {
|
|
199
|
+
'npu_json_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
200
|
+
FileCheckConst.READ_ABLE).common_check(),
|
|
201
|
+
'bench_json_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
202
|
+
FileCheckConst.READ_ABLE).common_check()
|
|
203
|
+
}
|
|
204
|
+
has_real_data |= get_dump_mode(dump_path_param) == Const.ALL
|
|
205
|
+
return has_real_data
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _mp_compare(input_param, serializable_args, output_file_name, nr, br):
|
|
209
|
+
graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br)
|
|
210
|
+
return _run_graph_compare(graph_task_info, input_param, serializable_args, output_file_name)
|
|
115
211
|
|
|
116
212
|
|
|
117
213
|
def _compare_graph_ranks(input_param, args, step=None):
|
|
214
|
+
with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
|
|
215
|
+
def err_call(err):
|
|
216
|
+
logger.error(f'Error occurred while comparing graph ranks: {err}')
|
|
217
|
+
try:
|
|
218
|
+
pool.close()
|
|
219
|
+
except OSError as e:
|
|
220
|
+
logger.error(f'Error occurred while terminating the pool: {e}')
|
|
221
|
+
|
|
222
|
+
serializable_args = SerializableArgs(args)
|
|
223
|
+
# 暂存所有rank的graph,用于匹配rank间的分布式节点
|
|
224
|
+
compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call)
|
|
225
|
+
|
|
226
|
+
# 匹配rank间的分布式节点
|
|
227
|
+
if len(compare_graph_results) > 1:
|
|
228
|
+
DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
|
|
229
|
+
args.overflow_check).distributed_match()
|
|
230
|
+
DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
|
|
231
|
+
args.overflow_check).distributed_match()
|
|
232
|
+
|
|
233
|
+
export_res_task_list = []
|
|
234
|
+
create_directory(args.output_path)
|
|
235
|
+
for result in compare_graph_results:
|
|
236
|
+
export_res_task_list.append(pool.apply_async(_export_compare_graph_result,
|
|
237
|
+
args=(serializable_args, result),
|
|
238
|
+
error_callback=err_call))
|
|
239
|
+
export_res_list = [res.get() for res in export_res_task_list]
|
|
240
|
+
if any(export_res_list):
|
|
241
|
+
failed_names = list(filter(lambda x: x, export_res_list))
|
|
242
|
+
logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
|
|
243
|
+
else:
|
|
244
|
+
logger.info('Successfully exported compare graph results.')
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
|
|
118
248
|
dump_rank_n = input_param.get('npu_path')
|
|
119
249
|
dump_rank_b = input_param.get('bench_path')
|
|
120
250
|
npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
|
|
@@ -123,32 +253,33 @@ def _compare_graph_ranks(input_param, args, step=None):
|
|
|
123
253
|
logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
|
|
124
254
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
125
255
|
compare_graph_results = []
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
256
|
+
if is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
257
|
+
mp_task_dict = {}
|
|
258
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
259
|
+
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
260
|
+
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
261
|
+
output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
|
|
262
|
+
input_param_copy = deepcopy(input_param)
|
|
263
|
+
mp_task_dict[output_file_name] = pool.apply_async(_run_build_graph_compare,
|
|
264
|
+
args=(input_param_copy, serializable_args, nr, br),
|
|
265
|
+
error_callback=err_call)
|
|
266
|
+
|
|
267
|
+
mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
|
|
268
|
+
for output_file_name, mp_res in mp_res_dict.items():
|
|
269
|
+
compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args, output_file_name))
|
|
270
|
+
else:
|
|
271
|
+
compare_graph_tasks = []
|
|
272
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
273
|
+
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
274
|
+
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
275
|
+
output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
|
|
276
|
+
input_param_copy = deepcopy(input_param)
|
|
277
|
+
compare_graph_tasks.append(pool.apply_async(_mp_compare,
|
|
278
|
+
args=(input_param_copy, serializable_args, output_file_name, nr,
|
|
279
|
+
br),
|
|
280
|
+
error_callback=err_call))
|
|
281
|
+
compare_graph_results = [task.get() for task in compare_graph_tasks]
|
|
282
|
+
return compare_graph_results
|
|
152
283
|
|
|
153
284
|
|
|
154
285
|
def _compare_graph_steps(input_param, args):
|
|
@@ -172,28 +303,39 @@ def _compare_graph_steps(input_param, args):
|
|
|
172
303
|
|
|
173
304
|
def _build_graph_ranks(dump_ranks_path, args, step=None):
|
|
174
305
|
ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
|
|
180
|
-
result = _build_graph(dump_path, args)
|
|
181
|
-
result.output_file_name = output_file_name
|
|
182
|
-
if rank != Const.RANK:
|
|
306
|
+
serializable_args = SerializableArgs(args)
|
|
307
|
+
with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
|
|
308
|
+
def err_call(err):
|
|
309
|
+
logger.error(f'Error occurred while comparing graph ranks: {err}')
|
|
183
310
|
try:
|
|
184
|
-
|
|
185
|
-
except
|
|
186
|
-
logger.error('
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
311
|
+
pool.close()
|
|
312
|
+
except OSError as e:
|
|
313
|
+
logger.error(f'Error occurred while terminating the pool: {e}')
|
|
314
|
+
|
|
315
|
+
build_graph_tasks = []
|
|
316
|
+
for rank in ranks:
|
|
317
|
+
build_graph_tasks.append(pool.apply_async(_run_build_graph_single,
|
|
318
|
+
args=(dump_ranks_path, rank, step, serializable_args),
|
|
319
|
+
error_callback=err_call))
|
|
320
|
+
build_graph_results = [task.get() for task in build_graph_tasks]
|
|
321
|
+
|
|
322
|
+
if len(build_graph_results) > 1:
|
|
323
|
+
DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
|
|
324
|
+
args.overflow_check).distributed_match()
|
|
325
|
+
|
|
326
|
+
create_directory(args.output_path)
|
|
327
|
+
export_build_graph_tasks = []
|
|
328
|
+
for result in build_graph_results:
|
|
329
|
+
export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result,
|
|
330
|
+
args=(serializable_args, result),
|
|
331
|
+
error_callback=err_call))
|
|
332
|
+
export_build_graph_result = [task.get() for task in export_build_graph_tasks]
|
|
333
|
+
if any(export_build_graph_result):
|
|
334
|
+
failed_names = list(filter(lambda x: x, export_build_graph_result))
|
|
335
|
+
logger.error(f'Unable to export build graph results: {failed_names}.')
|
|
336
|
+
else:
|
|
337
|
+
logger.info(f'Successfully exported build graph results.')
|
|
193
338
|
|
|
194
|
-
for result in build_graph_results:
|
|
195
|
-
_export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check,
|
|
196
|
-
result.output_file_name)
|
|
197
339
|
|
|
198
340
|
|
|
199
341
|
def _build_graph_steps(dump_steps_path, args):
|
|
@@ -209,7 +351,7 @@ def _graph_service_parser(parser):
|
|
|
209
351
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
210
352
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
211
353
|
help="<Required> The compare task result out path.", required=True)
|
|
212
|
-
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
|
|
354
|
+
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
213
355
|
help="<Optional> The layer mapping file path.", required=False)
|
|
214
356
|
parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
|
|
215
357
|
help="<Optional> whether open overflow_check for graph.", required=False)
|
|
@@ -233,8 +375,11 @@ def _graph_service_command(args):
|
|
|
233
375
|
elif content == GraphConst.STEPS:
|
|
234
376
|
_build_graph_steps(npu_path, args)
|
|
235
377
|
else:
|
|
236
|
-
result =
|
|
237
|
-
|
|
378
|
+
result = _build_graph_result(npu_path, args)
|
|
379
|
+
create_directory(args.output_path)
|
|
380
|
+
file_name = _export_build_graph_result(args, result)
|
|
381
|
+
if file_name:
|
|
382
|
+
logger.error('Failed to export model build graph.')
|
|
238
383
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
239
384
|
content_n = check_directory_content(npu_path)
|
|
240
385
|
content_b = check_directory_content(bench_path)
|
|
@@ -245,9 +390,11 @@ def _graph_service_command(args):
|
|
|
245
390
|
elif content_n == GraphConst.STEPS:
|
|
246
391
|
_compare_graph_steps(input_param, args)
|
|
247
392
|
else:
|
|
248
|
-
result =
|
|
249
|
-
|
|
250
|
-
|
|
393
|
+
result = _compare_graph_result(input_param, args)
|
|
394
|
+
create_directory(args.output_path)
|
|
395
|
+
file_name = _export_compare_graph_result(args, result)
|
|
396
|
+
if file_name:
|
|
397
|
+
logger.error('Failed to export model compare graph.')
|
|
251
398
|
else:
|
|
252
399
|
logger.error("The npu_path or bench_path should be a folder.")
|
|
253
400
|
raise CompareException(CompareException.INVALID_COMPARE_MODE)
|
msprobe/visualization/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,9 +16,10 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
18
|
import json
|
|
19
|
+
import pickle
|
|
19
20
|
from msprobe.core.common.file_utils import FileOpen
|
|
20
21
|
from msprobe.core.common.const import CompareConst, Const
|
|
21
|
-
from msprobe.core.
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
def load_json_file(file_path):
|
|
@@ -42,15 +43,6 @@ def load_data_json_file(file_path):
|
|
|
42
43
|
return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
|
|
43
44
|
|
|
44
45
|
|
|
45
|
-
def get_csv_df(stack_mode, csv_data, compare_mode):
|
|
46
|
-
"""
|
|
47
|
-
调用acc接口写入csv
|
|
48
|
-
"""
|
|
49
|
-
dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
|
|
50
|
-
mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=dump_mode)
|
|
51
|
-
return Comparator(mode_config).make_result_table(csv_data)
|
|
52
|
-
|
|
53
|
-
|
|
54
46
|
def str2float(percentage_str):
|
|
55
47
|
"""
|
|
56
48
|
百分比字符串转换转换为浮点型
|
|
@@ -192,3 +184,27 @@ class GraphConst:
|
|
|
192
184
|
OP = 'op'
|
|
193
185
|
PEER = 'peer'
|
|
194
186
|
GROUP_ID = 'group_id'
|
|
187
|
+
|
|
188
|
+
IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty',
|
|
189
|
+
'empty_strided'}
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def is_serializable(obj):
|
|
193
|
+
"""
|
|
194
|
+
Check if an object is serializable
|
|
195
|
+
"""
|
|
196
|
+
try:
|
|
197
|
+
pickle.dumps(obj)
|
|
198
|
+
return True
|
|
199
|
+
except (pickle.PicklingError, AttributeError, TypeError):
|
|
200
|
+
return False
|
|
201
|
+
except Exception as e:
|
|
202
|
+
logger.error('Unexpected error occurred while pickling obj.')
|
|
203
|
+
raise RuntimeError('Unexpected error occurred while pickling obj.') from e
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class SerializableArgs:
|
|
207
|
+
def __init__(self, args):
|
|
208
|
+
for k, v in vars(args).items():
|
|
209
|
+
if is_serializable(v):
|
|
210
|
+
setattr(self, k, v)
|
|
@@ -1,106 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Copyright 2024 Huawei Technologies Co., Ltd
|
|
3
|
-
*
|
|
4
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
* you may not use this file except in compliance with the License.
|
|
6
|
-
* You may obtain a copy of the License at
|
|
7
|
-
*
|
|
8
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
*
|
|
10
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
* See the License for the specific language governing permissions and
|
|
14
|
-
* limitations under the License.
|
|
15
|
-
*/
|
|
16
|
-
|
|
17
|
-
#include "hook_dynamic_loader.h"
|
|
18
|
-
#include <sys/stat.h>
|
|
19
|
-
#include <cstdlib>
|
|
20
|
-
#include <cstring>
|
|
21
|
-
#include <pybind11/embed.h>
|
|
22
|
-
#include "utils/log_adapter.h"
|
|
23
|
-
|
|
24
|
-
namespace py = pybind11;
|
|
25
|
-
|
|
26
|
-
HookDynamicLoader &HookDynamicLoader::GetInstance() {
|
|
27
|
-
static HookDynamicLoader instance;
|
|
28
|
-
return instance;
|
|
29
|
-
}
|
|
30
|
-
|
|
31
|
-
bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionName) {
|
|
32
|
-
void *func = dlsym(handle, functionName.c_str());
|
|
33
|
-
if (!func) {
|
|
34
|
-
MS_LOG(WARNING) << "Could not load function: " << functionName << ", error: " << dlerror();
|
|
35
|
-
return false;
|
|
36
|
-
}
|
|
37
|
-
funcMap_[functionName] = func;
|
|
38
|
-
return true;
|
|
39
|
-
}
|
|
40
|
-
|
|
41
|
-
bool HookDynamicLoader::LoadLibrary() {
|
|
42
|
-
std::string msprobePath = "";
|
|
43
|
-
// 获取gil锁
|
|
44
|
-
py::gil_scoped_acquire acquire;
|
|
45
|
-
try {
|
|
46
|
-
py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
|
|
47
|
-
if (!py::hasattr(msprobeMod, "__file__")) {
|
|
48
|
-
MS_LOG(WARNING) << "Adump mod not found";
|
|
49
|
-
return false;
|
|
50
|
-
}
|
|
51
|
-
msprobePath = msprobeMod.attr("__file__").cast<std::string>();
|
|
52
|
-
} catch (const std::exception& e) {
|
|
53
|
-
MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
|
|
54
|
-
return false;
|
|
55
|
-
}
|
|
56
|
-
std::lock_guard<std::mutex> lock(mutex_);
|
|
57
|
-
if (handle_) {
|
|
58
|
-
MS_LOG(WARNING) << "Hook library already loaded!";
|
|
59
|
-
return false;
|
|
60
|
-
}
|
|
61
|
-
if (msprobePath == "") {
|
|
62
|
-
MS_LOG(WARNING) << "Adump path not loaded";
|
|
63
|
-
return false;
|
|
64
|
-
}
|
|
65
|
-
handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
|
66
|
-
if (!handle_) {
|
|
67
|
-
MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
|
|
68
|
-
return false;
|
|
69
|
-
}
|
|
70
|
-
|
|
71
|
-
for (const auto &functionName : functionList_) {
|
|
72
|
-
if (!loadFunction(handle_, functionName)) {
|
|
73
|
-
MS_LOG(WARNING) << "Failed to load adump function";
|
|
74
|
-
dlclose(handle_);
|
|
75
|
-
handle_ = nullptr;
|
|
76
|
-
return false;
|
|
77
|
-
}
|
|
78
|
-
}
|
|
79
|
-
|
|
80
|
-
MS_LOG(INFO) << "Hook library loaded successfully.";
|
|
81
|
-
return true;
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
bool HookDynamicLoader::UnloadLibrary() {
|
|
85
|
-
std::lock_guard<std::mutex> lock(mutex_);
|
|
86
|
-
if (!handle_) {
|
|
87
|
-
MS_LOG(WARNING) << "Hook library hasn't been loaded.";
|
|
88
|
-
return false;
|
|
89
|
-
}
|
|
90
|
-
|
|
91
|
-
dlclose(handle_);
|
|
92
|
-
handle_ = nullptr;
|
|
93
|
-
funcMap_.clear();
|
|
94
|
-
MS_LOG(INFO) << "Library unloaded successfully.";
|
|
95
|
-
return true;
|
|
96
|
-
}
|
|
97
|
-
|
|
98
|
-
void *HookDynamicLoader::GetHooker(const std::string &funcName) {
|
|
99
|
-
std::lock_guard<std::mutex> lock(mutex_);
|
|
100
|
-
auto iter = funcMap_.find(funcName);
|
|
101
|
-
if (iter == funcMap_.end()) {
|
|
102
|
-
MS_LOG(WARNING) << "Function not found: " << funcName;
|
|
103
|
-
return nullptr;
|
|
104
|
-
}
|
|
105
|
-
return iter->second;
|
|
106
|
-
}
|