mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -0,0 +1,255 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+ from collections import defaultdict
18
+ import os
19
+ from itertools import dropwhile, chain
20
+
21
+ from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir
22
+ from msprobe.core.common.log import logger
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.nan_analyze.utils import (RankPath, FileCache, is_communication_op, is_ignore_op, NanAnalyseConst,
25
+ analyze_anomaly_in_group)
26
+ from msprobe.nan_analyze.graph import DataNode, CommunicationNode
27
+
28
+
29
class NanAnalyzer:
    """Locate the first anomalous (nan/inf-producing) node across per-rank dump data.

    The analysis runs in three stages (see ``analyze``):
      1. ``_pre_analyze``  - compute nodes before each rank's first communication op;
      2. ``_analyze``      - communication nodes and the compute ops between them;
      3. ``_post_analyze`` - compute nodes after the last communication op.
    As soon as a stage records anomaly nodes, a JSON report is written and the
    remaining stages are skipped.
    """

    def __init__(self, input_path, output_path):
        # input_path: step-level dump directory containing rank* sub-directories.
        # output_path: directory where the analyze-result JSON is written.
        self._input_path = input_path
        self._output_path = output_path
        self._paths = {}  # rank -> RankPath
        self._resolve_input_path()
        self._anomaly_nodes = []  # all anomaly nodes found so far
        self._cache = FileCache()
        self._first_comm_nodes = {}  # rank -> op_name of that rank's first communication node
        self._after_comm_anomalies = {}  # rank -> anomalous compute nodes after the last communication node
        self._rank_comm_nodes_dict = {}  # rank -> {node_id: CommunicationNode}

    def analyze(self):
        """Run the three stages in order; stop and emit the report at the first
        stage that finds anomaly nodes. Logs when nothing anomalous is found."""
        for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]:
            analyze_func()
            if self._anomaly_nodes:
                self._gen_analyze_info()
                return
        logger.info('Cannot find any anomaly node, no need to generate analyze file.')

    def _resolve_input_path(self):
        """Scan input_path for 'rank*' directories and build RankPath entries.

        A bare 'rank' directory maps to rank 0; non-numeric suffixes are skipped.
        """
        contents = os.listdir(self._input_path)
        for path in contents:
            if not path.startswith('rank'):
                continue
            # NOTE(review): str.strip('rank') strips a *character set* from both
            # ends, not the literal prefix. Fine for names like 'rank0'/'rank12',
            # but e.g. 'rank1k' would be mis-parsed — consider removeprefix('rank').
            rank_str = path.strip('rank')
            if not rank_str:
                rank = 0
            elif not rank_str.isdigit():
                continue
            else:
                rank = int(rank_str)
            dump_path = os.path.join(self._input_path, path, NanAnalyseConst.DUMP_FILE)
            construct_path = os.path.join(self._input_path, path, NanAnalyseConst.CONSTRUCT_FILE)
            stack_path = os.path.join(self._input_path, path, NanAnalyseConst.STACK_FILE)
            self._paths[rank] = RankPath(rank, dump_path, construct_path, stack_path)

    def _pre_analyze(self):
        """Per rank, scan dump data up to the first communication op; record the
        first anomalous compute node (if any) and remember where comms start."""
        logger.info('Start searching anomaly node before communication.')
        for path in self._paths.values():
            dump_data = self._cache.load_json(path.dump_path).get('data')
            if not dump_data:
                logger.warning(f'Rank {path.rank} has no dump data!')
                continue
            for op_name, op_data in dump_data.items():
                if is_communication_op(op_name):
                    self._first_comm_nodes[path.rank] = op_name
                    break
                data_node = DataNode(op_name, path.rank, op_data)
                if data_node.is_anomaly():
                    self._anomaly_nodes.append(data_node)
                    break

    def _analyze(self):
        """Build each rank's communication-node chain, connect peers across
        ranks, prune clean nodes, then search for the first anomaly."""
        logger.info('Start searching anomaly node during communication.')
        self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths}
        self._connect_comm_nodes()
        self._pruning()
        self._search_first_anomaly()

    def _post_analyze(self):
        """Report the first anomalous compute node found after the last
        communication node on each rank."""
        logger.info('Start searching anomaly node after communication.')
        for nodes in self._after_comm_anomalies.values():
            if nodes:
                self._anomaly_nodes.append(nodes[0])

    def _gen_analyze_info(self):
        """Write the collected anomaly nodes to a timestamped JSON file grouped
        by 'rank_<N>' keys under the output directory."""
        if not os.path.exists(self._output_path):
            make_dir(self._output_path)
        file_name = f'anomaly_analyze_{time.time_ns()}.json'
        result_file = os.path.join(self._output_path, file_name)
        result_content = defaultdict(list)
        for node in self._anomaly_nodes:
            result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank]))
        save_json(result_file, result_content, 2)
        logger.info(f"The analyze result is saved in: {result_file}")

    def _analyze_comm_nodes(self, rank):
        """Walk one rank's dump data starting at its first communication op and
        return {node_id: CommunicationNode}, chaining nodes in call order.

        Anomalous compute nodes between two comm nodes are attached to the
        following comm node; trailing ones go to ``_after_comm_anomalies``.
        """
        path = self._paths[rank]
        data = self._cache.load_json(path.dump_path).get('data')
        communication_nodes = {}
        if rank not in self._first_comm_nodes:  # this rank has no communication node
            return communication_nodes
        last_node_id = None  # node_id of the previous communication node
        compute_ops = []  # anomalous compute nodes between two communication nodes
        sub_layer = 0  # call-order index of compute nodes between two communication ops
        for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data):
            node_id = f'{rank}.{op_name}'
            op_data = data[op_name]
            if is_communication_op(op_name):
                comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer),
                                              compute_ops=compute_ops)
                if last_node_id:
                    communication_nodes.get(last_node_id).add_next(comm_node)
                communication_nodes[node_id] = comm_node
                last_node_id = node_id
                compute_ops = []
                sub_layer = 0
            elif not is_ignore_op(op_name):
                data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer)
                if data_node.is_anomaly():
                    compute_ops.append(data_node)
                sub_layer += 1
        if compute_ops:
            self._after_comm_anomalies[rank] = compute_ops
        return communication_nodes

    def _connect_comm_nodes(self):
        """For every comm node on all ranks but the last, find and link its peer
        nodes on higher (not-yet-searched) ranks."""
        searched_ranks = set()
        for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
            searched_ranks.add(rank)
            seen_nodes = set()
            for cur_node in nodes.values():
                conn_info = cur_node.find_connected_nodes()
                if not conn_info.get('ranks'):
                    # no explicit peer ranks recorded: consider every rank
                    conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
                if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
                    logger.info(f'Cannot find connected communication node for "{cur_node.node_id}".')

    def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes):
        """Try to link ``cur_node`` to one matching peer per candidate rank.

        Returns True when the node ends up connected (or already was).
        """
        def connect():
            # wire up the found peer according to its direction type
            seen_nodes.add(search_node.node_id)
            if search_node.type == NanAnalyseConst.DST:
                cur_node.add_dst(search_node)
            elif search_node.type == NanAnalyseConst.SRC:
                search_node.layer = cur_node.layer
                search_node.add_dst(cur_node)
            else:
                cur_node.add_link(search_node)

        found = cur_node.connected
        for connected_rank in conn_info['ranks']:
            if connected_rank in searched_ranks:
                continue
            tar_id_prefix = f'{connected_rank}.{conn_info["api"]}'
            for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items():
                if search_id in seen_nodes:
                    continue
                if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')):
                    continue
                search_conn_ranks = search_node.find_connected_nodes().get('ranks')
                if ((not search_conn_ranks and search_node.api not in NanAnalyseConst.DIRECTED_API) or
                        cur_node.rank in search_conn_ranks):  # some undirected comm ops omit ProcessGroup and by default connect all ranks
                    connect()
                    found = True
                    break
        return found

    def _pruning(self):
        """Remove comm nodes that carry no nan/inf and no attached anomalous
        compute ops, detaching them from all their neighbors."""
        deleted_node_id = []
        for nodes in self._rank_comm_nodes_dict.values():
            for node_id in list(nodes.keys()):
                node = nodes[node_id]
                if node.has_nan_inf() or node.compute_ops:
                    continue
                deleted_node_id.append(node_id)
                node.delete()
                del nodes[node_id]
        logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]')

    def _search_first_anomaly(self):
        """Group connected comm nodes layer by layer, analyze each group, and
        keep only the single earliest anomaly (min (layer, sub_layer))."""
        nodes_queues = []
        for comm_nodes in self._rank_comm_nodes_dict.values():
            nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer))
        seen_nodes = set()

        def get_next_node(node_list):
            # pop from the front until an unseen node is found
            while node_list:
                next_node = node_list.pop(0)
                if next_node.node_id not in seen_nodes:
                    return next_node
            return None

        def find_all_members(ori_node):
            # transitive closure of link/src/dst relations starting at ori_node
            ids = get_relative_ids(ori_node)
            id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids]))
            while id_queue:
                new_id = id_queue.pop(0)
                ids.add(new_id)
                id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids))
            return ids

        def get_relative_ids(ori_node):
            # ids of the node itself plus all directly linked peers
            if not ori_node:
                return set()
            return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() |
                    ori_node.dst_nodes.keys())

        while any(nodes_queues):
            groups = []
            all_ids_in_groups = set()
            for nodes in nodes_queues:
                node = get_next_node(nodes)
                if not node:
                    continue
                # NOTE(review): this condition starts a new group when none exist
                # OR when the node is already in a group — looks like it may be
                # intended as 'node.node_id not in all_ids_in_groups'; confirm.
                if not groups or node.node_id in all_ids_in_groups:
                    new_group = find_all_members(node)
                    groups.append(new_group)
                    all_ids_in_groups.update(new_group)
            for group in groups:
                seen_nodes.update(group)
                self._anomaly_nodes.extend(analyze_anomaly_in_group([self._get_node_by_id(n_id) for n_id in group]))
            if self._anomaly_nodes:
                # keep only the earliest anomaly in call order
                self._anomaly_nodes = [min(self._anomaly_nodes, key=lambda x: (x.layer, x.sub_layer))]
                return

    def _get_node_by_id(self, node_id):
        """Resolve '<rank>.<op_name>' to its CommunicationNode (None if pruned).

        Raises RuntimeError for malformed ids.
        """
        splits = node_id.split(Const.SEP, 1)
        if len(splits) < 2 or not splits[0].isdigit():
            logger.error(f'invalid node_id {node_id}')
            raise RuntimeError(f'invalid node_id {node_id}')
        rank = int(splits[0])
        return self._rank_comm_nodes_dict.get(rank, {}).get(node_id)
242
+
243
+
244
def _nan_analyze_parser(parser):
    """Register the nan-analyze CLI arguments on *parser*."""
    parser.add_argument(
        "-i", "--input_path",
        dest="input_path",
        default="",
        type=str,
        required=True,
        help="<Required> The dump file path, over step level. eg: \"xxx/step_0/\".",
    )
    parser.add_argument(
        "-o", "--output_path",
        dest="output_path",
        default="./output",
        type=str,
        required=False,
        help="<optional> The nan inf analyze result output file path.",
    )
251
+
252
+
253
def _run_nan_analyze(args):
    """CLI entry point: validate the input directory, then run the analyzer."""
    check_file_or_directory_path(args.input_path, True)
    analyzer = NanAnalyzer(args.input_path, args.output_path)
    analyzer.analyze()
@@ -0,0 +1,189 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.common.log import logger
19
+ from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst
20
+
21
+
22
+ @dataclass
23
+ class DataNode:
24
+ op_name: str
25
+ rank: int
26
+ inputs: list
27
+ input_args: list
28
+ input_kwargs: dict
29
+ outputs: dict
30
+ layer: int = 0 # 和communication_node的layer保持一致
31
+ sub_layer: int = 0 # 调用顺序,越小表示越先调用
32
+
33
+ def __init__(self, op_name, rank, op_data, **kwargs):
34
+ self.op_name = op_name
35
+ self.rank = rank
36
+ self.inputs = op_data.get(Const.INPUT, [])
37
+ self.input_args = op_data.get(Const.INPUT_ARGS, [])
38
+ self.input_kwargs = op_data.get(Const.INPUT_KWARGS, {})
39
+ self.outputs = op_data.get(Const.OUTPUT, {})
40
+ self.sub_layer = kwargs.get('sub_layer', 0)
41
+
42
+ @staticmethod
43
+ def find_complete_construct(construct_info, op_name):
44
+ construct = [op_name]
45
+ seen = set(op_name)
46
+ while True:
47
+ op_name = construct_info.get(op_name)
48
+ if not op_name or op_name in seen:
49
+ return construct
50
+ construct.insert(0, op_name)
51
+ seen.add(op_name)
52
+
53
+ def find_stack(self, stack_info):
54
+ for item in stack_info.values():
55
+ if len(item) >= 2 and self.op_name in item[0]:
56
+ return item[1]
57
+ return {}
58
+
59
+ def is_anomaly(self) -> bool:
60
+ if is_ignore_op(self.op_name):
61
+ return False
62
+ is_input_anomaly = (check_item_anomaly(self.inputs) or check_item_anomaly(self.input_args) or
63
+ check_item_anomaly(self.input_kwargs))
64
+ is_output_anomaly = check_item_anomaly(self.outputs)
65
+ return (not is_input_anomaly) and is_output_anomaly
66
+
67
+ def gen_node_info(self, path: RankPath):
68
+ cache = FileCache()
69
+ construct = cache.load_json(path.construct_path)
70
+ stack = cache.load_json(path.stack_path)
71
+ if Const.FORWARD in self.op_name:
72
+ data_info_list = {Const.INPUT_ARGS: self.input_args, Const.INPUT_KWARGS: self.input_kwargs,
73
+ Const.OUTPUT: self.outputs}
74
+ else:
75
+ data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs}
76
+ return {'op_name': self.op_name,
77
+ 'data_info': data_info_list,
78
+ 'construct_info': self.find_complete_construct(construct, self.op_name),
79
+ 'stack_info': self.find_stack(stack)}
80
+
81
+
82
class CommunicationNode:
    """One communication-op call on a rank, linked to its peers on other ranks.

    Relations maintained between nodes:
      pre_node / next_nodes - call order on the same rank;
      link_nodes            - undirected peers (collective ops);
      src_nodes / dst_nodes - directed peers (e.g. send/recv);
      compute_ops           - anomalous compute nodes since the previous comm node.
    """

    def __init__(self, node_id, rank, data: DataNode, layer=0, **kwargs):
        """``node_id`` is '<rank>.<op_name>'; ``data`` is the dumped op record.

        Raises RuntimeError when op_name has fewer than 4 '.'-separated parts.
        """
        self.node_id = node_id
        self.rank = rank
        self.data = data
        self.layer = layer  # shared "wave" index across connected ranks
        op_name_split = self.data.op_name.split(Const.SEP)
        if len(op_name_split) < 4:
            logger.error(f'invalid op_name: {self.data.op_name}')
            raise RuntimeError(f'invalid op_name: {self.data.op_name}')
        # op_name has the shape 'Distributed.<api>.<call_cnt>....'
        # (the 'Distributed.' prefix is rebuilt in find_connected_nodes)
        self.api = op_name_split[1]
        self.call_cnt = op_name_split[2]
        self.pre_node = kwargs.get('pre_node')
        self.link_nodes = kwargs.get('link_nodes', {})
        self.dst_nodes = kwargs.get('dst_nodes', {})
        self.src_nodes = kwargs.get('src_nodes', {})
        self.next_nodes = kwargs.get('next_nodes', {})
        self.compute_ops = kwargs.get('compute_ops', [])
        self.type = self._resolve_type()
        self.connected = False

    def add_next(self, node):
        """Append *node* as this rank's next comm node; its layer is ours + 1."""
        self.next_nodes[node.node_id] = node
        node.pre_node = self
        node.layer = self.layer + 1
        node.data.layer = node.layer

    def add_link(self, node):
        """Link an undirected peer; both ends share a layer and become connected."""
        self.link_nodes[node.node_id] = node
        node.link_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def add_dst(self, node):
        """Register *node* as a directed destination of this node (same layer)."""
        self.dst_nodes[node.node_id] = node
        node.src_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def delete(self):
        """Detach this node from all neighbors (used when pruning clean nodes)."""
        for node in self.next_nodes.values():
            node.pre_node = None
        for node in self.dst_nodes.values():
            node.src_nodes.pop(self.node_id)
        for node in self.src_nodes.values():
            node.dst_nodes.pop(self.node_id)
        for node in self.link_nodes.values():
            node.link_nodes.pop(self.node_id)
        if self.pre_node:
            self.pre_node.next_nodes.pop(self.node_id)

    def has_nan_inf(self):
        """True when any input or output of this op contains nan/inf."""
        return self.input_has_nan_inf() or check_item_anomaly(self.data.outputs)

    def input_has_nan_inf(self):
        """True when positional or keyword inputs contain nan/inf."""
        return check_item_anomaly(self.data.input_args) or check_item_anomaly(self.data.input_kwargs)

    def find_connected_nodes(self):
        """Determine the peer node identity from the api, direction type,
        arguments and call count.

        Returns a dict with candidate peer 'ranks', the peer op-name prefix
        'api' and the expected peer direction 'type'.
        """
        tar_api = NanAnalyseConst.P2P_API_MAPPING.get(self.api, self.api)
        ranks = set()
        for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]:
            if dst in self.data.input_kwargs:
                dst_value = self.data.input_kwargs.get(dst)
                if dst_value:
                    ranks.add(dst_value.get('value'))
                break
        for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]:
            if src in self.data.input_kwargs:
                src_value = self.data.input_kwargs.get(src)
                if src_value:
                    ranks.add(src_value.get('value'))
                break
        if not ranks:
            # no src/dst kwargs: a peer rank may be passed as a plain int arg
            for item in self.data.input_args:
                if isinstance(item, dict) and item.get(Const.TYPE) == 'int':
                    ranks.add(item.get('value'))
        group = self.data.input_kwargs.get('group')
        if group:
            ranks.update(group.get('group_ranks'))
        return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
                'type': NanAnalyseConst.OPPOSITE_DIR.get(self.type, NanAnalyseConst.LINK)}

    def _resolve_type(self):
        """Classify this node as SRC/DST/LINK from the src/dst kwargs, or from
        int args for APIs listed in NanAnalyseConst.DIRECTED_API."""
        for src in [NanAnalyseConst.SRC, NanAnalyseConst.SRC_GROUP]:
            if src in self.data.input_kwargs and self.data.input_kwargs[src]:
                if self.data.input_kwargs[src].get('value') == self.rank:
                    return NanAnalyseConst.SRC
                else:
                    return NanAnalyseConst.DST
        for dst in [NanAnalyseConst.DST, NanAnalyseConst.DST_GROUP]:
            if dst in self.data.input_kwargs and self.data.input_kwargs[dst]:
                if self.data.input_kwargs[dst].get('value') == self.rank:
                    return NanAnalyseConst.DST
                else:
                    return NanAnalyseConst.SRC
        if self.api in NanAnalyseConst.DIRECTED_API:
            for item in self.data.input_args:
                if item.get(Const.TYPE) == 'int':
                    node_type = NanAnalyseConst.DIRECTED_API[self.api]
                    return node_type if item.get('value') == self.rank else NanAnalyseConst.OPPOSITE_DIR[node_type]
        return NanAnalyseConst.LINK
@@ -0,0 +1,211 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import OrderedDict
17
+ from dataclasses import dataclass
18
+ import sys
19
+ import time
20
+ import psutil
21
+
22
+ from msprobe.core.common.const import CompareConst
23
+ from msprobe.core.common.file_utils import check_file_or_directory_path, load_json
24
+
25
+
26
@dataclass
class RankPath:
    """Per-rank bundle of dump artifact paths, each validated on construction."""
    rank: int
    dump_path: str
    construct_path: str
    stack_path: str

    def __init__(self, rank, dump_path, construct_path, stack_path):
        # Validate each path right before storing it; check_file_or_directory_path
        # raises on an invalid path, aborting construction.
        self.rank = rank
        for attr_name, path in (('dump_path', dump_path),
                                ('construct_path', construct_path),
                                ('stack_path', stack_path)):
            check_file_or_directory_path(path)
            setattr(self, attr_name, path)
41
+
42
+
43
class FileCache:
    """
    Lazily load JSON files and keep the parsed results in a bounded
    in-memory cache.

    Singleton: every FileCache() call returns the same instance. Eviction
    uses a hybrid LFU/LRU/size policy, capped at 1/4 of the memory that was
    available when the cache was first created.
    """
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__ accepts no extra arguments; forwarding *args
            # would raise TypeError if a caller ever passed any.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every FileCache() call even though __new__ returns
        # the shared instance; without this guard each call would wipe the
        # cache and reset the statistics, defeating the singleton.
        if getattr(self, '_initialized', False):
            return
        self._initialized = True
        self._max_memory_usage = psutil.virtual_memory().available / 4  # cap: 1/4 of currently available memory
        self._cache = OrderedDict()  # path -> parsed json, in access order
        self._access_cnt = {}        # path -> access count (LFU signal)
        self._access_time = {}       # path -> last access monotonic time (LRU signal)
        self._size = {}              # path -> deep size estimate in bytes

    @staticmethod
    def _sizeof(obj):
        """Best-effort deep size of *obj* in bytes (follows containers, skips cycles)."""
        seen = set()
        objs = [obj]
        size = 0
        while objs:
            obj = objs.pop()
            obj_id = id(obj)
            if obj_id in seen:
                continue
            seen.add(obj_id)
            size += sys.getsizeof(obj)
            if isinstance(obj, dict):
                objs.extend(obj.keys())
                objs.extend(obj.values())
            elif isinstance(obj, (list, tuple, set, frozenset)):
                objs.extend(obj)
        return size

    def load_json(self, json_path):
        """Return the parsed JSON for *json_path*, serving from cache when possible."""
        if json_path in self._cache:
            self._access_cnt[json_path] += 1
            self._access_time[json_path] = time.monotonic()
            self._cache.move_to_end(json_path)
            return self._cache[json_path]
        # Miss: make room first, then load and cache.
        self._cleanup()
        return self._load(json_path)

    def _load(self, json_path):
        # Read and parse the file, then register it in the cache.
        data = load_json(json_path)
        self._add_to_cache(json_path, data)
        return data

    def _add_to_cache(self, key, data):
        if key in self._cache:
            self._cache.move_to_end(key)
        else:
            self._cache[key] = data
            self._access_cnt[key] = 0
            self._access_time[key] = time.monotonic()
            self._size[key] = self._sizeof(data)

    def _calc_cache_size(self):
        # Container overhead plus the tracked deep sizes of all entries.
        return sys.getsizeof(self._cache) + sum(self._size.values())

    def _cleanup(self):
        # Evict until under budget. Candidates are the least-frequently-used,
        # least-recently-used and largest entries; the one scoring worst on
        # (frequency, recency, -size) is removed each round.
        while self._calc_cache_size() > self._max_memory_usage and self._cache:
            least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k])
            least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k])
            largest_key = max(self._cache.keys(), key=lambda k: self._size[k])
            key_to_rm = min([least_frequent_key, least_recent_key, largest_key],
                            key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k]))
            del self._cache[key_to_rm]
            del self._access_cnt[key_to_rm]
            del self._access_time[key_to_rm]
            del self._size[key_to_rm]
117
+
118
+
119
def is_communication_op(op_name):
    """Return True if *op_name* names a Distributed communication operator.

    The keyword set (all_reduce, send, broadcast, ...) is hard-coded in
    NanAnalyseConst.COMMUNICATION_KEYWORDS for now rather than read from
    the wrap file.
    """
    if not op_name.startswith('Distributed.'):
        return False
    return any(keyword in op_name for keyword in NanAnalyseConst.COMMUNICATION_KEYWORDS)
124
+
125
+
126
def is_ignore_op(op_name):
    """Return True for ops excluded from anomaly analysis.

    Torch.empty / Torch.fill produce uninitialized or constant memory, so
    nan/inf in their outputs is not meaningful.
    """
    return 'Torch.empty' in op_name or 'Torch.fill' in op_name
132
+
133
+
134
def check_item_anomaly(param):
    """Return True if any stat entry in *param* overflows.

    *param* is a list of stat dicts or a dict whose values are stat dicts;
    an entry is anomalous when its 'Max' or 'Min' field, stringified and
    lower-cased, appears in CompareConst.OVERFLOW_LIST (nan/inf markers).
    Non-dict entries and any other *param* type are ignored.
    """
    def overflowed(stat, field):
        return str(stat.get(field)).lower() in CompareConst.OVERFLOW_LIST

    if isinstance(param, list):
        candidates = param
    elif isinstance(param, dict):
        candidates = param.values()
    else:
        candidates = []
    return any(
        overflowed(stat, 'Max') or overflowed(stat, 'Min')
        for stat in candidates
        if isinstance(stat, dict)
    )
149
+
150
+
151
def analyze_anomaly_in_group(nodes_group):
    """Collect anomalous ops from one communication group.

    First, for src/link nodes whose *inputs* already contain nan/inf, the
    anomaly originated upstream: collect their preceding compute ops.
    (Simulating every compute op on CPU to pinpoint the source would need a
    full op mapping and is not implemented.) Then collect the data of every
    communication node whose own output is anomalous. Returns the combined
    list of anomalous op objects, each tagged with its node's layer.
    """
    anomaly_nodes = []

    def collect_compute_sources(comm_nodes):
        # Blame the compute ops feeding these communication nodes.
        for comm_node in comm_nodes:
            for op_node in comm_node.compute_ops:
                op_node.layer = comm_node.layer
                anomaly_nodes.append(op_node)

    def collect_comm_ops(comm_nodes):
        for node in comm_nodes:
            node.data.layer = node.layer
            anomaly_nodes.append(node.data)

    # Source or undirected nodes whose inputs carry nan/inf.
    src_or_link = [node for node in nodes_group
                   if node.type in [NanAnalyseConst.SRC, NanAnalyseConst.LINK]]
    input_anomalies = [node for node in src_or_link if node.input_has_nan_inf()]
    collect_compute_sources(input_anomalies)
    # Communication nodes flagged anomalous by their own data.
    collect_comm_ops([node for node in nodes_group if node.data.is_anomaly()])
    return anomaly_nodes
175
+
176
+
177
class NanAnalyseConst:
    """Constants for NaN/Inf propagation analysis of communication operators."""
    # Keywords identifying communication operators (all_reduce, send,
    # broadcast, ...). Ideally read from the wrap file; hard-coded for now.
    COMMUNICATION_KEYWORDS = {
        'send',  # send operator
        'recv',  # recv operator
        'broadcast',  # broadcast operator
        'all_reduce',  # all_reduce operator
        'reduce',  # reduce operator
        'all_gather',  # all_gather operator
        'gather',  # gather operator
        'isend',  # isend operator
        'irecv',  # irecv operator
        'scatter',  # scatter operator
        'reduce_scatter',  # reduce_scatter operator
        '_reduce_scatter_base',  # _reduce_scatter_base operator
        '_all_gather_base',  # _all_gather_base operator
        'all_to_all_single',  # all_to_all_single operator
        'all_to_all',  # all_to_all operator
        'all_gather_into_tensor',  # all_gather_into_tensor operator
        'reduce_scatter_tensor',  # reduce_scatter_tensor operator
        'send_object_list',  # send_object_list operator
        'recv_object_list'  # recv_object_list operator
    }
    # Point-to-point ops mapped to the counterpart api the peer rank executes.
    P2P_API_MAPPING = {'send': 'recv', 'recv': 'send', 'isend': 'irecv', 'irecv': 'isend',
                       'send_object_list': 'recv_object_list', 'recv_object_list': 'send_object_list'}
    # Node direction labels: source, destination, and undirected link.
    SRC = 'src'
    DST = 'dst'
    SRC_GROUP = 'src_group'
    DST_GROUP = 'dst_group'
    LINK = 'link'
    # For directed apis, the role played by the rank named in the positional arg.
    DIRECTED_API = {'send': DST, 'recv': SRC, 'isend': DST, 'irecv': SRC, 'broadcast': SRC, 'scatter': SRC,
                    'gather': DST, 'send_object_list': DST, 'recv_object_list': SRC}
    OPPOSITE_DIR = {SRC: DST, DST: SRC}
    # Expected artifact file names inside each rank's dump directory.
    DUMP_FILE = "dump.json"
    CONSTRUCT_FILE = "construct.json"
    STACK_FILE = "stack.json"
@@ -125,8 +125,8 @@ class CheckerConfig:
125
125
  save_error_data=config_params.get('save_error_data'),
126
126
  is_continue_run_ut=config_params.get('is_continue_run_ut'),
127
127
  real_data_path=config_params.get('real_data_path'),
128
- white_list=self.white_list,
129
- black_list=self.black_list,
128
+ white_list=self.white_list.copy() if self.white_list else [],
129
+ black_list=self.black_list.copy() if self.black_list else [],
130
130
  error_data_path=config_params.get('error_data_path'),
131
131
  online_config=self.get_online_config()
132
132
  )