mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import time
|
|
17
|
+
from collections import defaultdict
|
|
18
|
+
import os
|
|
19
|
+
from itertools import dropwhile, chain
|
|
20
|
+
|
|
21
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir
|
|
22
|
+
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.const import Const
|
|
24
|
+
from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst,
|
|
25
|
+
analyze_anomaly_in_group)
|
|
26
|
+
from msprobe.nan_analyze.graph import DataNode, CommunicationNode
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NanAnalyzer:
    """Locate the first nan/inf anomaly across a multi-rank dump directory.

    The search runs in three phases (see ``analyze``):
      1. ``_pre_analyze``  - compute ops before the first communication op, per rank;
      2. ``_analyze``      - communication ops (connected into a cross-rank graph)
                             and the compute ops between them;
      3. ``_post_analyze`` - compute ops after the last communication op, per rank.
    As soon as a phase records anomaly nodes, a JSON report is written and the
    remaining phases are skipped.
    """

    def __init__(self, input_path, output_path):
        # input_path: step-level dump directory containing rank* subdirectories.
        # output_path: directory where the analyze-result json is written.
        self._input_path = input_path
        self._output_path = output_path
        self._paths = {}
        self._resolve_input_path()
        self._anomaly_nodes = []  # all anomaly nodes found so far
        self._cache = FileCache()
        self._first_comm_nodes = {}  # rank -> node_id of the first communication node on that rank
        self._after_comm_anomalies = {}  # rank -> anomalous compute nodes after the last comm node
        self._rank_comm_nodes_dict = {}  # rank -> {node_id: CommunicationNode}

    def analyze(self):
        """Run the three phases in order; stop and emit the report at the
        first phase that finds any anomaly node."""
        for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]:
            analyze_func()
            if self._anomaly_nodes:
                self._gen_analyze_info()
                return
        logger.info('Cannot find any anomaly node, no need to generate analyze file.')

    def _resolve_input_path(self):
        """Scan input_path for rank directories and build RankPath entries."""
        contents = os.listdir(self._input_path)
        for path in contents:
            if not path.startswith('rank'):
                continue
            # NOTE(review): str.strip('rank') removes the characters r/a/n/k from
            # both ends, not the literal prefix — names whose digits are followed
            # by those letters could mis-parse; consider removeprefix('rank').
            rank_str = path.strip('rank')
            if not rank_str:
                # a bare "rank" directory (single-rank dump) is treated as rank 0
                rank = 0
            elif not rank_str.isdigit():
                continue
            else:
                rank = int(rank_str)
            dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE)
            construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE)
            stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE)
            self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path)

    def _pre_analyze(self):
        """Phase 1: per rank, scan dump data up to the first communication op
        and record the first anomalous compute node, if any."""
        logger.info('Start searching anomaly node before communication.')
        for path in self._paths.values():
            dump_data = self._cache.load_json(path.dump_path).get('data')
            if not dump_data:
                logger.warning(f'Rank {path.rank} has no dump data!')
                continue
            for op_name, op_data in dump_data.items():
                if is_communication_op(op_name):
                    # remember where communication starts; phase 2 resumes here
                    self._first_comm_nodes[path.rank] = op_name
                    break
                data_node = DataNode(op_name, path.rank, op_data)
                if data_node.is_anomaly():
                    self._anomaly_nodes.append(data_node)
                    break

    def _analyze(self):
        """Phase 2: build per-rank communication-node chains, connect them
        across ranks, prune clean nodes, then search for the first anomaly."""
        logger.info('Start searching anomaly node during communication.')
        self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths}
        self._connect_comm_nodes()
        self._pruning()
        self._search_first_anomaly()

    def _post_analyze(self):
        """Phase 3: take the first recorded anomaly after the last
        communication node of each rank."""
        logger.info('Start searching anomaly node after communication.')
        for nodes in self._after_comm_anomalies.values():
            if nodes:
                self._anomaly_nodes.append(nodes[0])

    def _gen_analyze_info(self):
        """Write the collected anomaly nodes to a timestamped json report."""
        if not os.path.exists(self._output_path):
            make_dir(self._output_path)
        file_name = f'anomaly_analyze_{time.time_ns()}.json'
        result_file = os.path.join(self._output_path, file_name)
        result_content = defaultdict(list)
        for node in self._anomaly_nodes:
            result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank]))
        save_json(result_file, result_content, 2)
        logger.info(f"The analyze result is saved in: {result_file}")

    def _analyze_comm_nodes(self, rank):
        """Build the chain of CommunicationNodes for one rank, attaching the
        anomalous compute nodes found between consecutive communication ops.

        Returns a dict {node_id: CommunicationNode}; trailing anomalous compute
        ops (after the last comm op) go into self._after_comm_anomalies.
        """
        path = self._paths[rank]
        data = self._cache.load_json(path.dump_path).get('data')
        communication_nodes = {}
        if rank not in self._first_comm_nodes:  # this rank has no communication node
            return communication_nodes
        last_node_id = None  # node_id of the previous communication node
        compute_ops = []  # compute nodes seen between two communication nodes
        sub_layer = 0  # call ordinal of anomalous compute nodes between two comm ops
        for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data):
            node_id = f'{rank}.{op_name}'
            op_data = data[op_name]
            if is_communication_op(op_name):
                comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer),
                                              compute_ops=compute_ops)
                if last_node_id:
                    communication_nodes.get(last_node_id).add_next(comm_node)
                communication_nodes[node_id] = comm_node
                last_node_id = node_id
                compute_ops = []
                sub_layer = 0
            elif not is_ignore_op(op_name):
                data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer)
                if data_node.is_anomaly():
                    compute_ops.append(data_node)
                sub_layer += 1
        if compute_ops:
            self._after_comm_anomalies[rank] = compute_ops
        return communication_nodes

    def _connect_comm_nodes(self):
        """Link each rank's communication nodes to their peers on other ranks.
        The last rank is skipped as the outer driver: its links are created
        while searching the earlier ranks."""
        searched_ranks = set()
        for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
            searched_ranks.add(rank)
            seen_nodes = set()
            for cur_node in nodes.values():
                conn_info = cur_node.find_connected_nodes()
                if not conn_info.get('ranks'):
                    # no ProcessGroup info: assume the op spans all ranks
                    conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
                if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
                    logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".')

    def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes):
        """Search the not-yet-searched ranks for nodes matching conn_info and
        connect them to cur_node. Returns True if cur_node ends up connected."""
        def connect():
            # wire the found node according to its direction relative to cur_node
            seen_nodes.add(search_node.node_id)
            if search_node.type == NanAnalyseConst.DST:
                cur_node.add_dst(search_node)
            elif search_node.type == NanAnalyseConst.SRC:
                search_node.layer = cur_node.layer
                search_node.add_dst(cur_node)
            else:
                cur_node.add_link(search_node)

        found = cur_node.connected
        for connected_rank in conn_info['ranks']:
            if connected_rank in searched_ranks:
                continue
            tar_id_prefix = f'{connected_rank}.{conn_info["api"]}'
            for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items():
                if search_id in seen_nodes:
                    continue
                if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')):
                    continue
                search_conn_ranks = search_node.find_connected_nodes().get('ranks')
                if ((not search_conn_ranks and search_node.api not in NanAnalyseConst.DIRECTED_API) or
                        cur_node.rank in search_conn_ranks):  # some undirected comm ops omit ProcessGroup and default to connecting all ranks
                    connect()
                    found = True
                    break
        return found

    def _pruning(self):
        """Drop communication nodes with no nan/inf and no anomalous compute
        ops attached, unlinking them from the graph."""
        deleted_node_id = []
        for nodes in self._rank_comm_nodes_dict.values():
            for node_id in list(nodes.keys()):
                node = nodes[node_id]
                if node.has_nan_inf() or node.compute_ops:
                    continue
                deleted_node_id.append(node_id)
                node.delete()
                del nodes[node_id]
        logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]')

    def _search_first_anomaly(self):
        """Walk the connected communication groups layer by layer and keep only
        the earliest anomaly (minimal (layer, sub_layer))."""
        nodes_queues = []
        for comm_nodes in self._rank_comm_nodes_dict.values():
            nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer))
        seen_nodes = set()

        def get_next_node(node_list):
            # pop until an unvisited node is found; None when the queue drains
            while node_list:
                next_node = node_list.pop(0)
                if next_node.node_id not in seen_nodes:
                    return next_node
            return None

        def find_all_members(ori_node):
            # transitive closure of link/src/dst relations starting at ori_node
            ids = get_relative_ids(ori_node)
            id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids]))
            while id_queue:
                new_id = id_queue.pop(0)
                ids.add(new_id)
                id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids))
            return ids

        def get_relative_ids(ori_node):
            # direct neighbours (plus self); empty set for a pruned/missing node
            if not ori_node:
                return set()
            return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() |
                    ori_node.dst_nodes.keys())

        while any(nodes_queues):
            groups = []
            all_ids_in_groups = set()
            for nodes in nodes_queues:
                node = get_next_node(nodes)
                if not node:
                    continue
                # NOTE(review): a new group is formed when none exists yet or when
                # this node already belongs to a collected group — verify this is
                # the intended membership condition.
                if not groups or node.node_id in all_ids_in_groups:
                    new_group = find_all_members(node)
                    groups.append(new_group)
                    all_ids_in_groups.update(new_group)
            for group in groups:
                seen_nodes.update(group)
                self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group]))
            if self._anomaly_nodes:
                # keep only the earliest anomaly across all collected groups
                self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))]
                return

    def _get_node_by_id(self, node_id):
        """Resolve 'rank.op_name' back to its CommunicationNode (or None)."""
        splits = node_id.split(Const.SEP, 1)
        if len(splits) < 2 or not splits[0].isdigit():
            logger.error(f'invalid node_id {node_id}')
            raise RuntimeError(f'invalid node_id {node_id}')
        rank = int(splits[0])
        return self._rank_comm_nodes_dict.get(rank, {}).get(node_id)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _nan_analyze_parser(parser):
|
|
245
|
+
parser.add_argument("-i", "--input_path", dest="input_path", default="", type=str,
|
|
246
|
+
help="<Required> The dump file path, over step level. eg: \"xxx/step_0/\".",
|
|
247
|
+
required=True)
|
|
248
|
+
parser.add_argument("-o", "--output_path", dest="output_path", default="./output", type=str,
|
|
249
|
+
help="<optional> The nan inf analyze result output file path.",
|
|
250
|
+
required=False)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _run_nan_analyze(args):
    """CLI handler: validate the input directory, then run the analysis."""
    check_file_or_directory_path(args.input_path, True)
    analyzer = NanAnalyzer(args.input_path, args.output_path)
    analyzer.analyze()
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from msprobe.core.common.const import Const
|
|
18
|
+
from msprobe.core.common.log import logger
|
|
19
|
+
from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class DataNode:
    """A single dumped operator record for one rank.

    Wraps the raw dump-json entry of one op and offers anomaly detection
    plus helpers for assembling the analyze report.
    """
    op_name: str
    rank: int
    inputs: list
    input_args: list
    input_kwargs: dict
    outputs: dict
    layer: int = 0  # kept in sync with the owning CommunicationNode's layer
    sub_layer: int = 0  # call order; smaller means called earlier

    def __init__(self, op_name, rank, op_data, **kwargs):
        # op_data is one entry of the dump json's "data" section.
        self.op_name = op_name
        self.rank = rank
        self.inputs = op_data.get(Const.INPUT, [])
        self.input_args = op_data.get(Const.INPUT_ARGS, [])
        self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {})
        self.outputs = op_data.get(Const.OUTPUT, {})
        self.sub_layer = kwargs.get('sub_layer', 0)

    @staticmethod
    def find_complete_construct(construct_info, op_name):
        """Walk construct_info upwards from op_name and return the full call
        chain, outermost scope first (ending with op_name itself).

        Bug fix: ``seen`` was initialized as ``set(op_name)`` — a set of the
        *characters* of op_name — so the cycle guard ``op_name in seen`` never
        matched full names, and a cyclic mapping produced duplicated entries.
        """
        construct = [op_name]
        seen = {op_name}
        while True:
            op_name = construct_info.get(op_name)
            # stop at the top of the chain or when a cycle is detected
            if not op_name or op_name in seen:
                return construct
            construct.insert(0, op_name)
            seen.add(op_name)

    def find_stack(self, stack_info):
        """Return the stack entry whose op-name part mentions this op, or {}."""
        for item in stack_info.values():
            # each entry is expected to be [op_names, stack_frames]
            if len(item) >= 2 and self.op_name in item[0]:
                return item[1]
        return {}

    def is_anomaly(self) -> bool:
        """True when every input is clean but some output contains nan/inf,
        i.e. this op is where the anomaly first appears."""
        if is_ignore_op(self.op_name):
            return False
        is_input_anomaly = (check_item_anomaly(self.inputs) or check_item_anomaly(self.input_args) or
                            check_item_anomaly(self.input_kwargs))
        is_output_anomaly = check_item_anomaly(self.outputs)
        return (not is_input_anomaly) and is_output_anomaly

    def gen_node_info(self, path: "RankPath"):
        """Assemble this node's report entry: data, construct chain and stack.

        Forward ops report args/kwargs; other ops report the plain input list.
        """
        cache = FileCache()
        construct = cache.load_json(path.construct_path)
        stack = cache.load_json(path.stack_path)
        if Const.FORWARD in self.op_name:
            data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs,
                              Const.OUTPUT: self.outputs}
        else:
            data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs}
        return {'op_name': self.op_name,
                'data_info': data_info_list,
                'construct_info': self.find_complete_construct(construct, self.op_name),
                'stack_info': self.find_stack(stack)}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class CommunicationNode:
    """One communication-operator call in a cross-rank communication graph.

    Wraps a DataNode for a single Distributed op and tracks its relations to
    other nodes: sequential order on the same rank (pre_node / next_nodes),
    same-transfer peers (link_nodes), and directed transfer endpoints
    (src_nodes / dst_nodes).
    """

    def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs):
        """Build a node from dumped op data.

        :param node_id: unique identifier of this node in the graph
        :param rank: rank on which the op was executed
        :param data: dumped op data; data.op_name must have at least 4
            Const.SEP-separated parts, e.g. 'Distributed.<api>.<call_cnt>.<...>'
        :param layer: depth of this node in the analysis graph
        :param kwargs: optional pre-built relations (pre_node, link_nodes,
            dst_nodes, src_nodes, next_nodes, compute_ops)
        :raises RuntimeError: if data.op_name has fewer than 4 parts
        """
        self.node_id = node_id
        self.rank = rank
        self.data = data
        self.layer = layer
        op_name_split = self.data.op_name.split(Const.SEP)
        if len(op_name_split) < 4:
            logger.error(f'invalid op_name: {self.data.op_name}')
            raise RuntimeError(f'invalid op_name: {self.data.op_name}')
        # op_name parts: [0]=framework prefix, [1]=api, [2]=call count, [3]=direction
        # (presumably 'Distributed.<api>.<cnt>.<forward/backward>' — confirmed only up to the length check)
        self.api = op_name_split[1]
        self.call_cnt = op_name_split[2]
        self.pre_node = kwargs.get('pre_node')
        self.link_nodes = kwargs.get('link_nodes', {})
        self.dst_nodes = kwargs.get('dst_nodes', {})
        self.src_nodes = kwargs.get('src_nodes', {})
        self.next_nodes = kwargs.get('next_nodes', {})
        self.compute_ops = kwargs.get('compute_ops', [])
        self.type = self._resolve_type()
        self.connected = False

    def add_next(self, node):
        """Append `node` as this node's successor, one layer deeper."""
        self.next_nodes[node.node_id] = node
        node.pre_node = self
        node.layer = self.layer + 1
        node.data.layer = node.layer

    def add_link(self, node):
        """Link `node` as a same-layer peer; both sides become connected."""
        self.link_nodes[node.node_id] = node
        node.link_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def add_dst(self, node):
        """Mark `node` as a destination of this node's transfer (same layer)."""
        self.dst_nodes[node.node_id] = node
        node.src_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def delete(self):
        """Detach this node from every neighbor's relation maps."""
        for node in self.next_nodes.values():
            node.pre_node = None
        for node in self.dst_nodes.values():
            node.src_nodes.pop(self.node_id)
        for node in self.src_nodes.values():
            node.dst_nodes.pop(self.node_id)
        for node in self.link_nodes.values():
            node.link_nodes.pop(self.node_id)
        if self.pre_node:
            self.pre_node.next_nodes.pop(self.node_id)

    def has_nan_inf(self):
        """Return True if any input or output of the op contains nan/inf."""
        return self.input_has_nan_inf() or check_item_anomaly(self.data.outputs)

    def input_has_nan_inf(self):
        """Return True if any positional or keyword input contains nan/inf."""
        return check_item_anomaly(self.data.input_args) or check_item_anomaly(self.data.input_kwargs)

    def find_connected_nodes(self):
        """Describe the nodes this one should be connected to, based on the
        api, node type, inputs, and call count.

        :return: dict with peer 'ranks', the expected peer 'api' name, and
            the expected peer 'type' (the opposite direction, or LINK)
        """
        # For p2p apis the peer runs the opposite api (send <-> recv).
        tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api)
        ranks = set()
        # Peer ranks from explicit dst/src kwargs, first matching key wins.
        for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]:
            if dst in self.data.input_kwargs:
                dst_value = self.data.input_kwargs.get(dst)
                if dst_value:
                    ranks.add(dst_value.get('value'))
                break
        for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]:
            if src in self.data.input_kwargs:
                src_value = self.data.input_kwargs.get(src)
                if src_value:
                    ranks.add(src_value.get('value'))
                break
        if not ranks:
            # Fall back to plain int positional args as peer rank candidates.
            for item in self.data.input_args:
                if isinstance(item, dict) and item.get(Const.TYPE) == 'int':
                    ranks.add(item.get('value'))
        group = self.data.input_kwargs.get('group')
        if group:
            # The 'group' kwarg carries the rank list of the whole group.
            ranks.update(group.get('group_ranks'))
        return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
                'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)}

    def _resolve_type(self):
        """Classify this node as SRC, DST, or LINK.

        SRC means this rank is the source of the transfer, DST the target,
        LINK an undirected (collective) op.
        """
        # Explicit src/src_group kwarg: we are SRC iff the named rank is us.
        for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]:
            if src in self.data.input_kwargs and self.data.input_kwargs[src]:
                if self.data.input_kwargs[src].get('value') == self.rank:
                    return NanAnalyseConst.SRC
                else:
                    return NanAnalyseConst.DST
        # Explicit dst/dst_group kwarg: we are DST iff the named rank is us.
        for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]:
            if dst in self.data.input_kwargs and self.data.input_kwargs[dst]:
                if self.data.input_kwargs[dst].get('value') == self.rank:
                    return NanAnalyseConst.DST
                else:
                    return NanAnalyseConst.SRC
        if self.api in NanAnalyseConst.DIRECTED_API:
            # Directed api without src/dst kwargs: the int positional arg is
            # the rank the direction refers to; flip if it is not this rank.
            for item in self.data.input_args:
                if item.get(Const.TYPE) == 'int':
                    node_type = NanAnalyseConst.DIRECTED_API[self.api]
                    return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type]
        return NanAnalyseConst.LINK
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright (c) 2025, Huawei Technologies Co., Ltd.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from collections import OrderedDict
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
import psutil
|
|
21
|
+
|
|
22
|
+
from msprobe.core.common.const import CompareConst
|
|
23
|
+
from msprobe.core.common.file_utils import check_file_or_directory_path, load_json
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class RankPath:
    """Validated file paths of one rank's dump artifacts."""
    rank: int
    dump_path: str
    construct_path: str
    stack_path: str

    def __init__(self, rank, dump_path, construct_path, stack_path):
        """Store the rank id and its three artifact paths.

        Each path is validated via check_file_or_directory_path before being
        stored; validation happens in dump -> construct -> stack order.
        """
        self.rank = rank
        path_fields = (
            ('dump_path', dump_path),
            ('construct_path', construct_path),
            ('stack_path', stack_path),
        )
        for attr_name, path in path_fields:
            check_file_or_directory_path(path)
            setattr(self, attr_name, path)
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class FileCache:
    """Lazy-loading, memory-bounded JSON cache (process-wide singleton).

    Parsed JSON is cached per path; when the cache exceeds 1/4 of the memory
    that was available at first construction, entries are evicted using a
    combined least-frequent / least-recent / largest heuristic.
    """
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # Do NOT forward *args/**kwargs to object.__new__ — with an
            # overridden __init__ that raises TypeError on Python 3.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on EVERY FileCache() call even though __new__ hands
        # back the singleton; without this guard each call would wipe the
        # cache and re-sample available memory, defeating the singleton.
        if getattr(self, '_initialized', False):
            return
        self._initialized = True
        self._max_memory_usage = psutil.virtual_memory().available / 4  # cap: 1/4 of currently available memory
        self._cache = OrderedDict()       # path -> parsed JSON, in LRU order
        self._access_cnt = {}             # path -> cache-hit count
        self._access_time = {}            # path -> last access (monotonic)
        self._size = {}                   # path -> deep size estimate in bytes

    @staticmethod
    def _sizeof(obj):
        """Deep size estimate of obj in bytes.

        Walks dict/list/tuple/set/frozenset containers iteratively, counting
        each distinct object (by id) once via sys.getsizeof.
        """
        seen = set()
        objs = [obj]
        size = 0
        while objs:
            obj = objs.pop()
            obj_id = id(obj)
            if obj_id in seen:
                continue
            seen.add(obj_id)
            size += sys.getsizeof(obj)
            if isinstance(obj, dict):
                objs.extend(obj.keys())
                objs.extend(obj.values())
            elif isinstance(obj, (list, tuple, set, frozenset)):
                objs.extend(obj)
        return size

    def load_json(self, json_path):
        """Return parsed JSON for json_path, served from cache when possible.

        Cache hits refresh recency/frequency bookkeeping; misses trigger an
        eviction pass before loading from disk.
        """
        if json_path in self._cache:
            self._access_cnt[json_path] += 1
            self._access_time[json_path] = time.monotonic()
            self._cache.move_to_end(json_path)
            return self._cache[json_path]
        self._cleanup()
        return self._load(json_path)

    def _load(self, json_path):
        """Load json_path from disk and insert it into the cache."""
        data = load_json(json_path)
        self._add_to_cache(json_path, data)
        return data

    def _add_to_cache(self, key, data):
        """Insert data under key, initializing its bookkeeping entries."""
        if key in self._cache:
            self._cache.move_to_end(key)
        else:
            self._cache[key] = data
            self._access_cnt[key] = 0
            self._access_time[key] = time.monotonic()
            self._size[key] = self._sizeof(data)

    def _calc_cache_size(self):
        """Estimated total memory held by the cache, in bytes."""
        return sys.getsizeof(self._cache) + sum(self._size.values())

    def _cleanup(self):
        """Evict entries until the cache fits under the memory cap.

        Each round considers the least-frequently used, least-recently used,
        and largest entries, and removes the one ranked worst by
        (access count, access time, -size).
        """
        while self._calc_cache_size() > self._max_memory_usage and self._cache:
            least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k])
            least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k])
            largest_key = max(self._cache.keys(), key=lambda k: self._size[k])
            key_to_rm = min([least_frequent_key, least_recent_key, largest_key],
                            key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k]))
            del self._cache[key_to_rm]
            del self._access_cnt[key_to_rm]
            del self._access_time[key_to_rm]
            del self._size[key_to_rm]
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def is_communication_op(op_name):
    """Return True if op_name denotes a Distributed communication operator.

    An op qualifies when it carries the 'Distributed.' prefix and contains
    one of the known communication keywords (all_reduce, send, broadcast,
    ...). The keyword list is hard-coded in NanAnalyseConst for now; it
    should eventually be read from the wrap files.
    """
    if not op_name.startswith('Distributed.'):
        return False
    return any(keyword in op_name for keyword in NanAnalyseConst.COMMUNICATION_KEYWORDS)
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def is_ignore_op(op_name):
    """Return True if op_name is an op excluded from anomaly analysis.

    Ops that merely allocate or fill memory (Torch.empty / Torch.fill) can
    legitimately hold garbage values, so nan/inf in them is not an anomaly.
    """
    for keyword in ('Torch.empty', 'Torch.fill'):
        if keyword in op_name:
            return True
    return False
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def check_item_anomaly(param):
    """Return True if any dumped tensor stat in param overflows (nan/inf).

    :param param: a list of stat dicts, or a dict whose values are stat
        dicts; anything else yields False
    :return: True when any dict entry's 'Max' or 'Min' value, lowered to a
        string, appears in CompareConst.OVERFLOW_LIST
    """
    def _overflows(entry, field):
        return str(entry.get(field)).lower() in CompareConst.OVERFLOW_LIST

    if isinstance(param, list):
        candidates = param
    elif isinstance(param, dict):
        candidates = param.values()
    else:
        candidates = []
    return any(
        _overflows(entry, 'Max') or _overflows(entry, 'Min')
        for entry in candidates
        if isinstance(entry, dict)
    )
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def analyze_anomaly_in_group(nodes_group):
    """Collect anomalous nodes from one group of communication nodes.

    Two passes, preserving this order in the result:
    1. For SRC/LINK nodes whose inputs already contain nan/inf, trace back
       to their preceding compute ops (the anomaly source). Simulating the
       compute nodes on CPU to verify would require recording/mapping every
       compute op and is not implemented yet.
    2. For nodes whose data is itself anomalous, collect the data node.

    :param nodes_group: iterable of CommunicationNode in the same group
    :return: list of anomalous nodes with their layer fields synchronized
    """
    anomaly_nodes = []

    # Pass 1: inputs of src/link nodes already bad -> blame compute ops.
    for comm_node in nodes_group:
        if comm_node.type not in (NanAnalyseConst.SRC, NanAnalyseConst.LINK):
            continue
        if not comm_node.input_has_nan_inf():
            continue
        for op_node in comm_node.compute_ops:
            op_node.layer = comm_node.layer
            anomaly_nodes.append(op_node)

    # Pass 2: communication nodes whose own data shows an anomaly.
    for comm_node in nodes_group:
        if comm_node.data.is_anomaly():
            comm_node.data.layer = comm_node.layer
            anomaly_nodes.append(comm_node.data)

    return anomaly_nodes
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class NanAnalyseConst:
    """Constants used by nan/inf communication-graph analysis."""

    # Keywords identifying communication operators in dumped op names.
    # Hard-coded for now; should eventually be read from the wrap files.
    COMMUNICATION_KEYWORDS = {
        # point-to-point
        'send', 'recv', 'isend', 'irecv',
        'send_object_list', 'recv_object_list',
        # rooted collectives
        'broadcast', 'scatter', 'gather',
        # reductions
        'reduce', 'all_reduce',
        # gathers
        'all_gather', '_all_gather_base', 'all_gather_into_tensor',
        # reduce-scatters
        'reduce_scatter', '_reduce_scatter_base', 'reduce_scatter_tensor',
        # all-to-all
        'all_to_all', 'all_to_all_single',
    }

    # Each point-to-point api mapped to the api its peer rank executes.
    P2P_API_MAPPING = {
        'send': 'recv', 'recv': 'send',
        'isend': 'irecv', 'irecv': 'isend',
        'send_object_list': 'recv_object_list',
        'recv_object_list': 'send_object_list',
    }

    SRC = 'src'
    DST = 'dst'
    SRC_GROUP = 'src_group'
    DST_GROUP = 'dst_group'
    LINK = 'link'

    # Directed apis: the role (SRC/DST) of the rank named in the arguments.
    DIRECTED_API = {
        'send': DST, 'isend': DST, 'send_object_list': DST, 'gather': DST,
        'recv': SRC, 'irecv': SRC, 'recv_object_list': SRC,
        'broadcast': SRC, 'scatter': SRC,
    }
    OPPOSITE_DIR = {SRC: DST, DST: SRC}

    DUMP_FILE = "dump.json"
    CONSTRUCT_FILE = "construct.json"
    STACK_FILE = "stack.json"
|
|
@@ -125,8 +125,8 @@ class CheckerConfig:
|
|
|
125
125
|
save_error_data=config_params.get('save_error_data'),
|
|
126
126
|
is_continue_run_ut=config_params.get('is_continue_run_ut'),
|
|
127
127
|
real_data_path=config_params.get('real_data_path'),
|
|
128
|
-
white_list=self.white_list,
|
|
129
|
-
black_list=self.black_list,
|
|
128
|
+
white_list=self.white_list.copy() if self.white_list else [],
|
|
129
|
+
black_list=self.black_list.copy() if self.black_list else [],
|
|
130
130
|
error_data_path=config_params.get('error_data_path'),
|
|
131
131
|
online_config=self.get_online_config()
|
|
132
132
|
)
|