mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -15,83 +15,93 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import time
|
|
18
|
-
import
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from multiprocessing import cpu_count, Pool
|
|
19
20
|
from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
|
|
20
21
|
check_file_or_directory_path, load_json)
|
|
21
22
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
22
|
-
from msprobe.core.common.utils import CompareException
|
|
23
|
-
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
23
|
+
from msprobe.core.common.utils import CompareException, get_dump_mode
|
|
24
24
|
from msprobe.visualization.compare.graph_comparator import GraphComparator
|
|
25
|
-
from msprobe.visualization.utils import GraphConst, check_directory_content
|
|
26
|
-
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
|
|
25
|
+
from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs
|
|
26
|
+
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
|
|
27
27
|
from msprobe.core.common.log import logger
|
|
28
28
|
from msprobe.visualization.graph.node_colors import NodeColors
|
|
29
29
|
from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
|
|
30
30
|
from msprobe.core.compare.utils import check_and_return_dir_contents
|
|
31
|
+
from msprobe.core.common.utils import detect_framework_by_dump_json
|
|
31
32
|
from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
|
|
32
33
|
|
|
33
34
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
def _compare_graph(input_param, args):
|
|
37
|
-
logger.info('Start building model graphs...')
|
|
38
|
-
# 对两个数据进行构图
|
|
39
|
-
dump_path_n = input_param.get('npu_path')
|
|
40
|
-
dump_path_b = input_param.get('bench_path')
|
|
41
|
-
construct_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE),
|
|
42
|
-
FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
43
|
-
construct_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE),
|
|
44
|
-
FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
45
|
-
data_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
46
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
47
|
-
data_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
48
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
49
|
-
stack_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
50
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
51
|
-
stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
52
|
-
FileCheckConst.READ_ABLE).common_check()
|
|
53
|
-
graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
|
|
54
|
-
graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
|
|
55
|
-
logger.info('Model graphs built successfully, start Comparing graphs...')
|
|
56
|
-
# 基于graph、stack和data进行比较
|
|
37
|
+
def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
|
|
57
38
|
dump_path_param = {
|
|
58
|
-
'npu_json_path':
|
|
59
|
-
'bench_json_path':
|
|
60
|
-
'stack_json_path':
|
|
39
|
+
'npu_json_path': graph_n.data_path,
|
|
40
|
+
'bench_json_path': graph_b.data_path,
|
|
41
|
+
'stack_json_path': graph_n.stack_path,
|
|
61
42
|
'is_print_compare_log': input_param.get("is_print_compare_log", True)
|
|
62
43
|
}
|
|
63
|
-
mapping_dict =
|
|
44
|
+
mapping_dict = {}
|
|
64
45
|
if args.layer_mapping:
|
|
65
|
-
yaml_path = FileChecker(args.layer_mapping, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
66
46
|
try:
|
|
67
|
-
mapping_dict = generate_api_mapping_by_layer_mapping(
|
|
47
|
+
mapping_dict = generate_api_mapping_by_layer_mapping(graph_n.data_path, graph_b.data_path,
|
|
48
|
+
args.layer_mapping)
|
|
68
49
|
except Exception:
|
|
69
50
|
logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')
|
|
70
|
-
|
|
51
|
+
is_cross_framework = detect_framework_by_dump_json(graph_n.data_path) != \
|
|
52
|
+
detect_framework_by_dump_json(graph_b.data_path)
|
|
53
|
+
if is_cross_framework and not args.layer_mapping:
|
|
54
|
+
logger.error('The cross_frame graph comparison failed. '
|
|
55
|
+
'Please specify -lm or --layer_mapping when performing cross_frame graph comparison.')
|
|
56
|
+
raise CompareException(CompareException.CROSS_FRAME_ERROR)
|
|
57
|
+
|
|
58
|
+
graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework,
|
|
59
|
+
mapping_dict=mapping_dict)
|
|
71
60
|
graph_comparator.compare()
|
|
72
|
-
|
|
61
|
+
return graph_comparator
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _compare_graph_result(input_param, args):
|
|
65
|
+
logger.info('Start building model graphs...')
|
|
66
|
+
# 对两个数据进行构图
|
|
67
|
+
graph_n = _build_graph_info(input_param.get('npu_path'), args)
|
|
68
|
+
graph_b = _build_graph_info(input_param.get('bench_path'), args)
|
|
69
|
+
logger.info('Model graphs built successfully, start Comparing graphs...')
|
|
70
|
+
# 基于graph、stack和data进行比较
|
|
71
|
+
graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
|
|
72
|
+
# 增加micro step标记
|
|
73
|
+
micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
|
|
73
74
|
# 开启溢出检测
|
|
74
75
|
if args.overflow_check:
|
|
75
|
-
graph_n.overflow_check()
|
|
76
|
-
graph_b.overflow_check()
|
|
76
|
+
graph_n.graph.overflow_check()
|
|
77
|
+
graph_b.graph.overflow_check()
|
|
77
78
|
|
|
78
|
-
return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
|
|
79
|
+
return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
|
|
79
80
|
|
|
80
81
|
|
|
81
|
-
def _export_compare_graph_result(args,
|
|
82
|
-
|
|
83
|
-
|
|
82
|
+
def _export_compare_graph_result(args, result):
|
|
83
|
+
graphs = [result.graph_n, result.graph_b]
|
|
84
|
+
graph_comparator = result.graph_comparator
|
|
85
|
+
micro_steps = result.micro_steps
|
|
86
|
+
output_file_name = result.output_file_name
|
|
87
|
+
if not output_file_name:
|
|
88
|
+
output_file_name = f'compare_{current_time}.vis'
|
|
89
|
+
logger.info(f'Start exporting compare graph result, file name: {output_file_name}...')
|
|
84
90
|
output_path = os.path.join(args.output_path, output_file_name)
|
|
85
91
|
task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
|
|
86
92
|
export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
|
|
87
93
|
NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
|
|
88
|
-
args.overflow_check)
|
|
89
|
-
|
|
90
|
-
|
|
94
|
+
args.overflow_check, graph_comparator.ma.compare_mode)
|
|
95
|
+
try:
|
|
96
|
+
GraphBuilder.to_json(output_path, export_config)
|
|
97
|
+
logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}')
|
|
98
|
+
return ''
|
|
99
|
+
except RuntimeError as e:
|
|
100
|
+
logger.error(f'Failed to export compare graph result, file: {output_file_name}, error: {e}')
|
|
101
|
+
return output_file_name
|
|
91
102
|
|
|
92
103
|
|
|
93
|
-
def
|
|
94
|
-
logger.info('Start building model graph...')
|
|
104
|
+
def _build_graph_info(dump_path, args):
|
|
95
105
|
construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
|
|
96
106
|
FileCheckConst.READ_ABLE).common_check()
|
|
97
107
|
data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
@@ -99,6 +109,13 @@ def _build_graph(dump_path, args):
|
|
|
99
109
|
stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
100
110
|
FileCheckConst.READ_ABLE).common_check()
|
|
101
111
|
graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
|
|
112
|
+
return GraphInfo(graph, construct_path, data_path, stack_path)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _build_graph_result(dump_path, args):
|
|
116
|
+
logger.info('Start building model graphs...')
|
|
117
|
+
graph = _build_graph_info(dump_path, args).graph
|
|
118
|
+
# 增加micro step标记
|
|
102
119
|
micro_steps = graph.paging_by_micro_step()
|
|
103
120
|
# 开启溢出检测
|
|
104
121
|
if args.overflow_check:
|
|
@@ -106,15 +123,128 @@ def _build_graph(dump_path, args):
|
|
|
106
123
|
return BuildGraphResult(graph, micro_steps)
|
|
107
124
|
|
|
108
125
|
|
|
109
|
-
def
|
|
110
|
-
|
|
111
|
-
|
|
126
|
+
def _run_build_graph_compare(input_param, args, nr, br):
|
|
127
|
+
logger.info(f'Start building graph for {nr}...')
|
|
128
|
+
graph_n = _build_graph_info(input_param.get('npu_path'), args)
|
|
129
|
+
graph_b = _build_graph_info(input_param.get('bench_path'), args)
|
|
130
|
+
logger.info(f'Building graph for {nr} finished.')
|
|
131
|
+
return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _run_build_graph_single(dump_ranks_path, rank, step, args):
|
|
135
|
+
logger.info(f'Start building graph for {rank}...')
|
|
136
|
+
dump_path = os.path.join(dump_ranks_path, rank)
|
|
137
|
+
output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
|
|
138
|
+
result = _build_graph_result(dump_path, args)
|
|
139
|
+
result.output_file_name = output_file_name
|
|
140
|
+
if rank != Const.RANK:
|
|
141
|
+
try:
|
|
142
|
+
result.rank = int(rank.replace(Const.RANK, ""))
|
|
143
|
+
except Exception as e:
|
|
144
|
+
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
145
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
146
|
+
logger.info(f'Building graph for step: {step}, rank: {rank} finished.')
|
|
147
|
+
return result
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
|
|
151
|
+
logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
|
|
152
|
+
graph_n = graph_task_info.graph_info_n
|
|
153
|
+
graph_b = graph_task_info.graph_info_b
|
|
154
|
+
nr = graph_task_info.npu_rank
|
|
155
|
+
graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
|
|
156
|
+
micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
|
|
157
|
+
# 开启溢出检测
|
|
158
|
+
if args.overflow_check:
|
|
159
|
+
graph_n.graph.overflow_check()
|
|
160
|
+
graph_b.graph.overflow_check()
|
|
161
|
+
graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
|
|
162
|
+
graph_result.output_file_name = output_file_name
|
|
163
|
+
if nr != Const.RANK:
|
|
164
|
+
try:
|
|
165
|
+
graph_result.rank = int(nr.replace(Const.RANK, ""))
|
|
166
|
+
except Exception as e:
|
|
167
|
+
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
168
|
+
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
169
|
+
logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.')
|
|
170
|
+
return graph_result
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _export_build_graph_result(args, result):
|
|
174
|
+
out_path = args.output_path
|
|
175
|
+
graph = result.graph
|
|
176
|
+
micro_steps = result.micro_steps
|
|
177
|
+
overflow_check = args.overflow_check
|
|
178
|
+
output_file_name = result.output_file_name
|
|
179
|
+
if not output_file_name:
|
|
180
|
+
output_file_name = f'build_{current_time}.vis'
|
|
181
|
+
logger.info(f'Start exporting graph for {output_file_name}...')
|
|
112
182
|
output_path = os.path.join(out_path, output_file_name)
|
|
113
|
-
|
|
114
|
-
|
|
183
|
+
try:
|
|
184
|
+
GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps,
|
|
185
|
+
overflow_check=overflow_check))
|
|
186
|
+
logger.info(f'Model graph exported successfully, the result file is saved in {output_path}')
|
|
187
|
+
return None
|
|
188
|
+
except RuntimeError as e:
|
|
189
|
+
logger.error(f'Failed to export model graph, file: {output_file_name}, error: {e}')
|
|
190
|
+
return output_file_name
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
194
|
+
dump_rank_n = input_param.get('npu_path')
|
|
195
|
+
dump_rank_b = input_param.get('bench_path')
|
|
196
|
+
has_real_data = False
|
|
197
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
198
|
+
dump_path_param = {
|
|
199
|
+
'npu_json_path': FileChecker(os.path.join(dump_rank_n, nr, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
200
|
+
FileCheckConst.READ_ABLE).common_check(),
|
|
201
|
+
'bench_json_path': FileChecker(os.path.join(dump_rank_b, br, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
202
|
+
FileCheckConst.READ_ABLE).common_check()
|
|
203
|
+
}
|
|
204
|
+
has_real_data |= get_dump_mode(dump_path_param) == Const.ALL
|
|
205
|
+
return has_real_data
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _mp_compare(input_param, serializable_args, output_file_name, nr, br):
|
|
209
|
+
graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br)
|
|
210
|
+
return _run_graph_compare(graph_task_info, input_param, serializable_args, output_file_name)
|
|
115
211
|
|
|
116
212
|
|
|
117
213
|
def _compare_graph_ranks(input_param, args, step=None):
|
|
214
|
+
with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
|
|
215
|
+
def err_call(err):
|
|
216
|
+
logger.error(f'Error occurred while comparing graph ranks: {err}')
|
|
217
|
+
try:
|
|
218
|
+
pool.close()
|
|
219
|
+
except OSError as e:
|
|
220
|
+
logger.error(f'Error occurred while terminating the pool: {e}')
|
|
221
|
+
|
|
222
|
+
serializable_args = SerializableArgs(args)
|
|
223
|
+
# 暂存所有rank的graph,用于匹配rank间的分布式节点
|
|
224
|
+
compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call)
|
|
225
|
+
|
|
226
|
+
# 匹配rank间的分布式节点
|
|
227
|
+
if len(compare_graph_results) > 1:
|
|
228
|
+
DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
|
|
229
|
+
args.overflow_check).distributed_match()
|
|
230
|
+
DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
|
|
231
|
+
args.overflow_check).distributed_match()
|
|
232
|
+
|
|
233
|
+
export_res_task_list = []
|
|
234
|
+
create_directory(args.output_path)
|
|
235
|
+
for result in compare_graph_results:
|
|
236
|
+
export_res_task_list.append(pool.apply_async(_export_compare_graph_result,
|
|
237
|
+
args=(serializable_args, result),
|
|
238
|
+
error_callback=err_call))
|
|
239
|
+
export_res_list = [res.get() for res in export_res_task_list]
|
|
240
|
+
if any(export_res_list):
|
|
241
|
+
failed_names = list(filter(lambda x: x, export_res_list))
|
|
242
|
+
logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
|
|
243
|
+
else:
|
|
244
|
+
logger.info('Successfully exported compare graph results.')
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _get_compare_graph_results(input_param, serializable_args, step, pool, err_call):
|
|
118
248
|
dump_rank_n = input_param.get('npu_path')
|
|
119
249
|
dump_rank_b = input_param.get('bench_path')
|
|
120
250
|
npu_ranks = sorted(check_and_return_dir_contents(dump_rank_n, Const.RANK))
|
|
@@ -123,32 +253,33 @@ def _compare_graph_ranks(input_param, args, step=None):
|
|
|
123
253
|
logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
|
|
124
254
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
125
255
|
compare_graph_results = []
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
256
|
+
if is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
257
|
+
mp_task_dict = {}
|
|
258
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
259
|
+
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
260
|
+
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
261
|
+
output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
|
|
262
|
+
input_param_copy = deepcopy(input_param)
|
|
263
|
+
mp_task_dict[output_file_name] = pool.apply_async(_run_build_graph_compare,
|
|
264
|
+
args=(input_param_copy, serializable_args, nr, br),
|
|
265
|
+
error_callback=err_call)
|
|
266
|
+
|
|
267
|
+
mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
|
|
268
|
+
for output_file_name, mp_res in mp_res_dict.items():
|
|
269
|
+
compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args, output_file_name))
|
|
270
|
+
else:
|
|
271
|
+
compare_graph_tasks = []
|
|
272
|
+
for nr, br in zip(npu_ranks, bench_ranks):
|
|
273
|
+
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
274
|
+
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
275
|
+
output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
|
|
276
|
+
input_param_copy = deepcopy(input_param)
|
|
277
|
+
compare_graph_tasks.append(pool.apply_async(_mp_compare,
|
|
278
|
+
args=(input_param_copy, serializable_args, output_file_name, nr,
|
|
279
|
+
br),
|
|
280
|
+
error_callback=err_call))
|
|
281
|
+
compare_graph_results = [task.get() for task in compare_graph_tasks]
|
|
282
|
+
return compare_graph_results
|
|
152
283
|
|
|
153
284
|
|
|
154
285
|
def _compare_graph_steps(input_param, args):
|
|
@@ -159,7 +290,7 @@ def _compare_graph_steps(input_param, args):
|
|
|
159
290
|
bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
|
|
160
291
|
|
|
161
292
|
if npu_steps != bench_steps:
|
|
162
|
-
logger.error('The number of steps in the two runs
|
|
293
|
+
logger.error('The number of steps in the two runs is different. Unable to match the steps.')
|
|
163
294
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
164
295
|
|
|
165
296
|
for folder_step in npu_steps:
|
|
@@ -172,28 +303,39 @@ def _compare_graph_steps(input_param, args):
|
|
|
172
303
|
|
|
173
304
|
def _build_graph_ranks(dump_ranks_path, args, step=None):
    """
    Build and export one graph per rank directory found under ``dump_ranks_path``.

    The rank directories are discovered with ``check_and_return_dir_contents``
    and processed in sorted order. Graph building and graph exporting are both
    dispatched to a process pool; the command-line ``args`` are wrapped in
    ``SerializableArgs`` before being handed to the workers.

    Args:
        dump_ranks_path: directory that contains the per-rank dump folders.
        args: parsed command-line arguments; ``output_path`` and
            ``overflow_check`` are read here.
        step: optional step identifier forwarded to the build worker.
    """
    ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
    serializable_args = SerializableArgs(args)
    # Use roughly a quarter of the available CPUs, but always at least one worker.
    with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
        def err_call(err):
            # This is the build path, so the message says "building"
            # (it previously said "comparing", copied from the compare routine).
            logger.error(f'Error occurred while building graph ranks: {err}')
            try:
                # Stop accepting new tasks once any worker has failed.
                pool.close()
            except OSError as e:
                logger.error(f'Error occurred while terminating the pool: {e}')

        build_graph_tasks = [
            pool.apply_async(_run_build_graph_single,
                             args=(dump_ranks_path, rank, step, serializable_args),
                             error_callback=err_call)
            for rank in ranks
        ]
        # .get() blocks until each task finishes and re-raises worker exceptions here.
        build_graph_results = [task.get() for task in build_graph_tasks]

        # Cross-rank matching only runs when more than one rank was built.
        if len(build_graph_results) > 1:
            DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
                                args.overflow_check).distributed_match()

        create_directory(args.output_path)
        export_build_graph_tasks = [
            pool.apply_async(_export_build_graph_result,
                             args=(serializable_args, result),
                             error_callback=err_call)
            for result in build_graph_results
        ]
        export_build_graph_result = [task.get() for task in export_build_graph_tasks]
        # Truthy return values are reported as failures; presumably they are the
        # names of the files that could not be exported - confirm against
        # _export_build_graph_result.
        failed_names = [name for name in export_build_graph_result if name]
        if failed_names:
            logger.error(f'Unable to export build graph results: {failed_names}.')
        else:
            logger.info('Successfully exported build graph results.')
|
|
193
338
|
|
|
194
|
-
for result in build_graph_results:
|
|
195
|
-
_export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check,
|
|
196
|
-
result.output_file_name)
|
|
197
339
|
|
|
198
340
|
|
|
199
341
|
def _build_graph_steps(dump_steps_path, args):
|
|
@@ -209,7 +351,7 @@ def _graph_service_parser(parser):
|
|
|
209
351
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
210
352
|
parser.add_argument("-o", "--output_path", dest="output_path", type=str,
|
|
211
353
|
help="<Required> The compare task result out path.", required=True)
|
|
212
|
-
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
|
|
354
|
+
parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
|
|
213
355
|
help="<Optional> The layer mapping file path.", required=False)
|
|
214
356
|
parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
|
|
215
357
|
help="<Optional> whether open overflow_check for graph.", required=False)
|
|
@@ -233,8 +375,11 @@ def _graph_service_command(args):
|
|
|
233
375
|
elif content == GraphConst.STEPS:
|
|
234
376
|
_build_graph_steps(npu_path, args)
|
|
235
377
|
else:
|
|
236
|
-
result =
|
|
237
|
-
|
|
378
|
+
result = _build_graph_result(npu_path, args)
|
|
379
|
+
create_directory(args.output_path)
|
|
380
|
+
file_name = _export_build_graph_result(args, result)
|
|
381
|
+
if file_name:
|
|
382
|
+
logger.error('Failed to export model build graph.')
|
|
238
383
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
239
384
|
content_n = check_directory_content(npu_path)
|
|
240
385
|
content_b = check_directory_content(bench_path)
|
|
@@ -245,9 +390,11 @@ def _graph_service_command(args):
|
|
|
245
390
|
elif content_n == GraphConst.STEPS:
|
|
246
391
|
_compare_graph_steps(input_param, args)
|
|
247
392
|
else:
|
|
248
|
-
result =
|
|
249
|
-
|
|
250
|
-
|
|
393
|
+
result = _compare_graph_result(input_param, args)
|
|
394
|
+
create_directory(args.output_path)
|
|
395
|
+
file_name = _export_compare_graph_result(args, result)
|
|
396
|
+
if file_name:
|
|
397
|
+
logger.error('Failed to export model compare graph.')
|
|
251
398
|
else:
|
|
252
399
|
logger.error("The npu_path or bench_path should be a folder.")
|
|
253
400
|
raise CompareException(CompareException.INVALID_COMPARE_MODE)
|
msprobe/visualization/utils.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024, Huawei Technologies Co., Ltd.
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -16,9 +16,10 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import re
|
|
18
18
|
import json
|
|
19
|
+
import pickle
|
|
19
20
|
from msprobe.core.common.file_utils import FileOpen
|
|
20
21
|
from msprobe.core.common.const import CompareConst, Const
|
|
21
|
-
from msprobe.core.
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
def load_json_file(file_path):
|
|
@@ -42,23 +43,6 @@ def load_data_json_file(file_path):
|
|
|
42
43
|
return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
|
|
43
44
|
|
|
44
45
|
|
|
45
|
-
def save_json_file(file_path, data):
|
|
46
|
-
"""
|
|
47
|
-
保存json文件
|
|
48
|
-
"""
|
|
49
|
-
with FileOpen(file_path, 'w') as f:
|
|
50
|
-
f.write(json.dumps(data, indent=4))
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
def get_csv_df(stack_mode, csv_data, compare_mode):
|
|
54
|
-
"""
|
|
55
|
-
调用acc接口写入csv
|
|
56
|
-
"""
|
|
57
|
-
dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
|
|
58
|
-
mode_config = ModeConfig(stack_mode=stack_mode, dump_mode=dump_mode)
|
|
59
|
-
return Comparator(mode_config).make_result_table(csv_data)
|
|
60
|
-
|
|
61
|
-
|
|
62
46
|
def str2float(percentage_str):
|
|
63
47
|
"""
|
|
64
48
|
百分比字符串转换转换为浮点型
|
|
@@ -73,14 +57,6 @@ def str2float(percentage_str):
|
|
|
73
57
|
return 0
|
|
74
58
|
|
|
75
59
|
|
|
76
|
-
def is_integer(s):
|
|
77
|
-
try:
|
|
78
|
-
int(s)
|
|
79
|
-
return True
|
|
80
|
-
except Exception:
|
|
81
|
-
return False
|
|
82
|
-
|
|
83
|
-
|
|
84
60
|
def check_directory_content(input_path):
|
|
85
61
|
"""
|
|
86
62
|
检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件
|
|
@@ -143,14 +119,12 @@ class ToolTip:
|
|
|
143
119
|
'当最大相对误差越接近0表示其计算的误差越小。'
|
|
144
120
|
'当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象'
|
|
145
121
|
)
|
|
146
|
-
SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差'
|
|
147
122
|
|
|
148
123
|
|
|
149
124
|
class GraphConst:
|
|
150
125
|
CONSTRUCT_FILE = 'construct.json'
|
|
151
126
|
DUMP_FILE = 'dump.json'
|
|
152
127
|
STACK_FILE = 'stack.json'
|
|
153
|
-
GRAPH_FILE = 'graph.vis'
|
|
154
128
|
ERROR_KEY = 'error_key'
|
|
155
129
|
SUMMARY_COMPARE = 0
|
|
156
130
|
MD5_COMPARE = 1
|
|
@@ -164,35 +138,22 @@ class GraphConst:
|
|
|
164
138
|
JSON_DATA_KEY = 'dump_data_dir'
|
|
165
139
|
JSON_TASK_KEY = 'task'
|
|
166
140
|
DATA_KEY = 'data'
|
|
167
|
-
REAL_DATA_TH = 0.1
|
|
168
|
-
MAX_RELATIVE_ERR_TH = 0.5
|
|
169
141
|
ROUND_TH = 6
|
|
170
142
|
JSON_INDEX_KEY = 'precision_index'
|
|
171
143
|
MATCHED_DISTRIBUTED = 'matched_distributed'
|
|
172
144
|
OVERFLOW_LEVEL = 'overflow_level'
|
|
173
145
|
MAX_INDEX_KEY = 1
|
|
174
146
|
MIN_INDEX_KEY = 0
|
|
175
|
-
SUGGEST_KEY = 'text'
|
|
176
|
-
TAG_NA = 'na'
|
|
177
|
-
OUTPUT_INDEX_TWO = -2
|
|
178
|
-
OUTPUT_INDEX_THREE = -3
|
|
179
|
-
OUTPUT_MIN_LEN = 3
|
|
180
147
|
INPUT = '.input.'
|
|
181
148
|
OUTPUT = '.output.'
|
|
182
149
|
STR_MAX_LEN = 50
|
|
183
|
-
SMALL_VALUE = 1e-3
|
|
184
150
|
MD5_INDEX_LIST = [CompareConst.RESULT]
|
|
185
|
-
REAL_DATA_INDEX_LIST =
|
|
186
|
-
|
|
187
|
-
SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
|
|
188
|
-
CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
|
|
189
|
-
CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
|
|
190
|
-
VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
|
|
151
|
+
REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
|
|
152
|
+
SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
|
|
191
153
|
APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
|
|
192
154
|
NULL = 'null'
|
|
193
155
|
NONE = 'None'
|
|
194
156
|
VALUE = 'value'
|
|
195
|
-
BRACE = '{}'
|
|
196
157
|
DESCRIPTION = 'description'
|
|
197
158
|
COLORS = 'Colors'
|
|
198
159
|
MICRO_STEPS = 'MicroSteps'
|
|
@@ -223,3 +184,24 @@ class GraphConst:
|
|
|
223
184
|
OP = 'op'
|
|
224
185
|
PEER = 'peer'
|
|
225
186
|
GROUP_ID = 'group_id'
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def is_serializable(obj):
    """
    Report whether ``obj`` can be pickled.

    Returns True when ``pickle.dumps`` succeeds and False when it fails with
    ``pickle.PicklingError``, ``AttributeError`` or ``TypeError``. Any other
    exception is logged and re-raised as a ``RuntimeError`` chained to the
    original error.
    """
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        # The usual ways pickling reports an unpicklable object.
        picklable = False
    except Exception as unexpected:
        logger.error('Unexpected error occurred while pickling obj.')
        raise RuntimeError('Unexpected error occurred while pickling obj.') from unexpected
    else:
        picklable = True
    return picklable
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class SerializableArgs:
    """
    Copy of an argument object that keeps only its picklable attributes.

    Every attribute exposed by ``vars(args)`` is tested with
    ``is_serializable``; those that pass are set on this instance under the
    same name, and the rest are left out.
    """

    def __init__(self, args):
        for attr_name, attr_value in vars(args).items():
            if not is_serializable(attr_value):
                # Skip anything that cannot be pickled.
                continue
            setattr(self, attr_name, attr_value)
|