mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -0,0 +1,135 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ from tqdm import tqdm
19
+
20
+ from msprobe.core.common.const import Const, CompareConst
21
+ from msprobe.core.common.utils import logger, CompareException
22
+ from msprobe.core.common.file_utils import load_yaml
23
+ from msprobe.core.compare.config import ModeConfig
24
+ from msprobe.core.compare.utils import gen_api_batches
25
+
26
+
27
# Directory that contains this module; the YAML resources below are resolved relative to it.
cur_dir = os.path.dirname(os.path.realpath(__file__))
diff_threshold_yaml_path = os.path.join(cur_dir, 'diff_analyze_threshold.yaml')
ignore_op_list_yaml_path = os.path.join(cur_dir, 'ignore_op_list.yaml')
# Both YAML files are loaded once, at import time.
# ignore_list is looked up by API name and yields output indices that the diff check skips.
ignore_list = load_yaml(ignore_op_list_yaml_path)
# thresholds is looked up by metric name; each entry is expected to be a single-element list.
thresholds = load_yaml(diff_threshold_yaml_path)
# Names of the metric columns that are compared against their thresholds.
cmp_metrics = thresholds.get('compare_metrics')
33
+
34
+
35
class FirstDiffAnalyze:
    """Decide, per API/module, whether the NPU and bench compare results differ.

    The decision uses the module-level ``thresholds``, ``cmp_metrics`` and
    ``ignore_list`` that are loaded from the YAML files next to this module.
    """

    def __init__(self, mode_config: ModeConfig, rank):
        """Store the compare mode configuration and the rank used in the progress bar."""
        self.mode_config = mode_config
        self.rank = rank

    @staticmethod
    def single_metric_diff_check(cmp_metric, metric_value):
        """Return True when ``metric_value`` exceeds the configured threshold.

        :param cmp_metric: metric (column) name used to look up the threshold
        :param metric_value: metric value; only strings ending with '%' are evaluated
        :return: True if the percentage (as a fraction) is above the threshold, else False
        :raises CompareException: if the threshold is missing or is not a single-element list
        """
        threshold = thresholds.get(cmp_metric, None)
        if threshold is None:
            logger.error(f"Check diff of {cmp_metric} need to configure the threshold. "
                         f"Please configure it in 'diff_analyze_threshold.yaml'.")
            raise CompareException(CompareException.MISSING_THRESHOLD_ERROR)
        if not isinstance(threshold, list) or len(threshold) != 1:
            logger.error(f"{cmp_metric} threshold configure wrong. Please check.")
            raise CompareException(CompareException.WRONG_THRESHOLD_ERROR)
        # Values that are not percentage strings are never reported as a diff.
        if isinstance(metric_value, str) and metric_value.endswith('%'):
            return float(metric_value[:-1]) / 100 > threshold[0]
        return False

    def single_api_check(self, result_slice, header, api_name=None):
        """Check a single API for differences.

        :param result_slice: rows of the compare result that belong to this API
        :param header: list of column names, in the same order as each row
        :param api_name: short API name used to look up output indices to ignore
        :return: {'is_same': bool, 'op_items': list[dict]}
        """
        single_check_result = {
            'is_same': True,
            'op_items': []
        }

        column_indices = {name: idx for idx, name in enumerate(header)}
        ignored_outputs = ignore_list.get(api_name, [])
        output_idx = -1
        for line in result_slice:
            op_item = {
                column_name: line[column_indices[column_name]]
                for column_name in header
            }
            single_check_result['op_items'].append(op_item)
            # Only output rows take part in the is_same decision.
            if op_item['state'] != 'output':
                continue
            output_idx += 1
            if output_idx in ignored_outputs:
                continue
            if self.mode_config.dump_mode == Const.MD5:
                if line[column_indices[CompareConst.RESULT]] == CompareConst.DIFF:
                    single_check_result['is_same'] = False
            else:
                for cmp_metric in cmp_metrics:
                    metric_value = line[column_indices[cmp_metric]]
                    if self.single_metric_diff_check(cmp_metric, metric_value):
                        single_check_result['is_same'] = False
                        break
        return single_check_result

    def check(self, result_df):
        """Iterate over the APIs of a compare result and check each one for differences.

        example of the returned value:
        {
            'Functional.conv2d.0.forward': {
                'is_same': True,
                'op_items': [
                    {
                        'NPU name': 'Functional.conv2d.0.forward.input.0',
                        'Bench name': 'Functional.conv2d.0.forward.input.0',
                        'xxx': 1,
                        'NormRelativeErr': 2,
                        'yyy': 3,
                        ...
                    }
                ]
            }
        }

        :param result_df: compare result; ``.values`` and ``.columns`` are read from it
        :return: dict mapping the full API name to its single_api_check result
        """
        result = result_df.values
        header = result_df.columns.tolist()

        api_batches = gen_api_batches(result, header)

        check_result = {}

        default_bar_desc = 'API/Module diff check Progress'
        bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
        with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="api/module", ncols=100) as progress_bar:
            for api_batch in api_batches:
                result_slice = result[api_batch.start: api_batch.params_grad_end_index]
                api_compo = api_batch.api_name.split('.')
                # suppose name is Tensor.MatMul.0.forward
                if len(api_compo) < 4:
                    # Skipped batches still count, so the bar can reach its total.
                    progress_bar.update(1)
                    continue
                # get MatMul as api_name
                api_name = api_compo[-3]
                check_result[api_batch.api_name] = self.single_api_check(result_slice, header, api_name)
                progress_bar.update(1)

        return check_result
@@ -0,0 +1,3 @@
1
+ npu_fusion_attention:
2
+ - 4
3
+ - 5
File without changes
@@ -0,0 +1,282 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+ from collections import defaultdict
18
+ import os
19
+ from itertools import dropwhile, chain
20
+
21
+ from msprobe.core.common import const
22
+ from msprobe.core.common.file_utils import check_file_or_directory_path, save_json, make_dir
23
+ from msprobe.core.common.log import logger
24
+ from msprobe.core.common.const import Const
25
+ from msprobe.core.compare.find_first.data_processor import DataProcessor
26
+ from msprobe.core.compare.find_first.utils import (RankPath, FileCache, is_communication_op, is_ignore_op,
27
+ DiffAnalyseConst, analyze_diff_in_group)
28
+ from msprobe.core.compare.find_first.graph import DataNode, CommunicationNode
29
+
30
+
31
+ class DiffAnalyzer:
32
+ def __init__(self, npu_path, bench_path, output_path, data_frame=Const.PT_FRAMEWORK):
33
+ self._bench_path = bench_path
34
+ self._npu_path = npu_path
35
+ self._output_path = output_path
36
+ self.pre_processor = DataProcessor(data_frame)
37
+ self._paths = {}
38
+ self._diff_nodes = [] # 记录所有异常节点
39
+ self._cache = FileCache()
40
+ self._first_comm_nodes = {} # 记录各rank下首个通信节点的node_id
41
+ self._after_comm_diffs = {} # 记录各rank下发生在通信节点之后的异常计算节点
42
+ self._rank_comm_nodes_dict = {} # 记录各rank的通信节点
43
+
44
+ def analyze(self):
45
+ self._pre_process()
46
+ for analyze_func in [self._pre_analyze, self._analyze, self._post_analyze]:
47
+ analyze_func()
48
+ if self._diff_nodes:
49
+ self._gen_analyze_info()
50
+ return
51
+ logger.info('Cannot find any diff node, no need to generate analyze file.')
52
+
53
+ def _pre_process(self):
54
+ self.pre_processor.process(self._npu_path, self._bench_path, self._output_path)
55
+ self._resolve_input_path(self._output_path)
56
+ logger.info("Pre Process completed.")
57
+
58
+ """
59
+ 这里需要生成stack,但是直接用dict中自带就行,在op_items.NPU_Stack_Info中
60
+ """
61
+ def _resolve_input_path(self, result_input_path):
62
+ contents = os.listdir(result_input_path)
63
+ rank_paths = {}
64
+
65
+ for path in contents:
66
+ # 检查文件名是否符合compare_result_rank{rank_id}_{timestamp}.json格式
67
+ if not path.startswith('compare_result_rank'):
68
+ continue
69
+ if not path.endswith('.json'):
70
+ continue
71
+
72
+ # 从文件名中提取rank_id
73
+ try:
74
+ path_ele_list = path.split('_')
75
+ if len(path_ele_list) <= 2:
76
+ continue
77
+ rank_part = path_ele_list[2]
78
+ if not rank_part.startswith('rank'):
79
+ continue
80
+ rank_str = rank_part.strip('rank') # 去掉'rank'前缀
81
+ rank = int(rank_str) if rank_str else 0
82
+ except (IndexError, ValueError):
83
+ continue
84
+
85
+ # 构建完整的json文件路径
86
+ dump_path = os.path.join(result_input_path, path)
87
+ rank_paths[rank] = RankPath(rank, dump_path)
88
+
89
+ # 按照rank id排序后添加到self._paths中
90
+ for rank in sorted(rank_paths.keys()):
91
+ self._paths[rank] = rank_paths[rank]
92
+
93
+ def _pre_analyze(self):
94
+ logger.info('Start searching diff node before communication.')
95
+ for path in self._paths.values():
96
+ dump_data = self._cache.load_json(path.dump_path)
97
+ if not dump_data:
98
+ logger.warning(f'Rank {path.rank} has no dump data!')
99
+ continue
100
+ for op_name, op_data in dump_data.items():
101
+ if is_ignore_op(op_name):
102
+ continue
103
+ if is_communication_op(op_name):
104
+ self._first_comm_nodes[path.rank] = op_name
105
+ break
106
+ data_node = DataNode(op_name, path.rank, op_data)
107
+ if data_node.is_diff:
108
+ self._diff_nodes.append(data_node)
109
+ break
110
+
111
+ def _analyze(self):
112
+ logger.info('Start searching diff node during communication.')
113
+ self._rank_comm_nodes_dict = {rank: self._analyze_comm_nodes(rank) for rank in self._paths}
114
+ self._connect_comm_nodes()
115
+ self._pruning()
116
+ self._search_first_diff()
117
+
118
+ def _post_analyze(self):
119
+ logger.info('Start searching diff node after communication.')
120
+ for nodes in self._after_comm_diffs.values():
121
+ if nodes:
122
+ self._diff_nodes.append(nodes[0])
123
+
124
+ def _connect_comm_nodes(self):
125
+ searched_ranks = set()
126
+ for rank, nodes in list(self._rank_comm_nodes_dict.items())[:-1]:
127
+ searched_ranks.add(rank)
128
+ seen_nodes = set()
129
+ last_node = None
130
+ for cur_node in nodes.values():
131
+ is_overflow = last_node and hasattr(last_node, 'layer') and hasattr(cur_node, 'layer') and \
132
+ last_node.layer >= cur_node.layer
133
+ if is_overflow:
134
+ cur_node.layer = last_node.layer + 1
135
+ conn_info = cur_node.find_connected_nodes()
136
+ if not conn_info.get('ranks'):
137
+ conn_info['ranks'] = self._rank_comm_nodes_dict.keys()
138
+ last_node = cur_node
139
+ if not self._find_connection(conn_info, cur_node, searched_ranks, seen_nodes):
140
+ logger.debug(f'Cannot find connected communication node for "{cur_node.node_id}".')
141
+
142
+ def _find_connection(self, conn_info, cur_node, searched_ranks, seen_nodes):
143
+ def connect(search_node):
144
+ seen_nodes.add(search_node.node_id)
145
+ if search_node.type == DiffAnalyseConst.DST:
146
+ cur_node.add_dst(search_node)
147
+ elif search_node.type == DiffAnalyseConst.SRC:
148
+ search_node.layer = cur_node.layer
149
+ search_node.add_dst(cur_node)
150
+ else:
151
+ cur_node.add_link(search_node)
152
+
153
+ found = cur_node.connected
154
+ for connected_rank in conn_info['ranks']:
155
+ if connected_rank in searched_ranks:
156
+ continue
157
+ tar_id_prefix = f'{connected_rank}.{conn_info["api"]}'
158
+ for search_id, search_node in self._rank_comm_nodes_dict[connected_rank].items():
159
+ if search_id in seen_nodes:
160
+ continue
161
+ if not (search_id.startswith(tar_id_prefix) and search_node.type == conn_info.get('type')):
162
+ continue
163
+ search_conn_ranks = search_node.find_connected_nodes().get('ranks')
164
+ if ((not search_conn_ranks and search_node.api not in DiffAnalyseConst.DIRECTED_API) or
165
+ cur_node.rank in search_conn_ranks): # 有些无向通信算子没有填ProcessGroup,默认连接所有rank
166
+ connect(search_node)
167
+ found = True
168
+ break
169
+ return found
170
+
171
+ def _analyze_comm_nodes(self, rank):
172
+ path = self._paths[rank]
173
+ data = self._cache.load_json(path.dump_path)
174
+ communication_nodes = {}
175
+ if rank not in self._first_comm_nodes: # 此rank没有通信节点
176
+ return communication_nodes
177
+ last_node_id = None # 记录上一个通信节点的node_id
178
+ compute_ops = [] # 记录两个通信节点之间的计算节点
179
+ sub_layer = 0 # 记录两个通信算子之间异常计算节点的调用序数
180
+ for op_name in dropwhile(lambda k: k != self._first_comm_nodes[rank], data):
181
+ node_id = f'{rank}.{op_name}'
182
+ op_data = data[op_name]
183
+ if is_communication_op(op_name):
184
+ comm_node = CommunicationNode(node_id, rank, DataNode(op_name, rank, op_data, sub_layer=sub_layer),
185
+ compute_ops=compute_ops)
186
+ if last_node_id:
187
+ communication_nodes.get(last_node_id).add_next(comm_node)
188
+ communication_nodes[node_id] = comm_node
189
+ last_node_id = node_id
190
+ compute_ops = []
191
+ sub_layer = 0
192
+ elif not is_ignore_op(op_name):
193
+ data_node = DataNode(op_name, rank, op_data, sub_layer=sub_layer)
194
+ if data_node.is_diff:
195
+ compute_ops.append(data_node)
196
+ sub_layer += 1
197
+ if compute_ops:
198
+ self._after_comm_diffs[rank] = compute_ops
199
+ return communication_nodes
200
+
201
+ def _pruning(self):
202
+ deleted_node_id = []
203
+ for nodes in self._rank_comm_nodes_dict.values():
204
+ for node_id in list(nodes.keys()):
205
+ node = nodes[node_id]
206
+ if node.is_diff or node.compute_ops:
207
+ continue
208
+ deleted_node_id.append(node_id)
209
+ node.delete()
210
+ del nodes[node_id]
211
+ logger.debug(f'After pruning, following nodes are removed: [{", ".join(deleted_node_id)}]')
212
+
213
+ def _search_first_diff(self):
214
+ nodes_queues = []
215
+ for comm_nodes in self._rank_comm_nodes_dict.values():
216
+ nodes_queues.append(sorted(list(comm_nodes.values()), key=lambda x: x.layer))
217
+ seen_nodes = set()
218
+
219
+ def get_next_node(node_list):
220
+ while node_list:
221
+ next_node = node_list.pop(0)
222
+ if next_node.node_id not in seen_nodes:
223
+ return next_node
224
+ return None
225
+
226
+ def find_all_members(ori_node):
227
+ ids = get_relative_ids(ori_node)
228
+ id_queue = list(chain(*[get_relative_ids(self._get_node_by_id(n_id)).difference(ids) for n_id in ids]))
229
+ while id_queue:
230
+ new_id = id_queue.pop(0)
231
+ ids.add(new_id)
232
+ id_queue.extend(get_relative_ids(self._get_node_by_id(new_id)).difference(ids))
233
+ return ids
234
+
235
+ def get_relative_ids(ori_node):
236
+ if not ori_node:
237
+ return set()
238
+ return ({ori_node.node_id} | ori_node.link_nodes.keys() | ori_node.src_nodes.keys() |
239
+ ori_node.dst_nodes.keys())
240
+
241
+ while any(nodes_queues):
242
+ groups = []
243
+ all_ids_in_groups = set()
244
+ for nodes in nodes_queues:
245
+ node = get_next_node(nodes)
246
+ if not node:
247
+ continue
248
+ if not groups or node.node_id not in all_ids_in_groups:
249
+ new_group = find_all_members(node)
250
+ groups.append(new_group)
251
+ all_ids_in_groups.update(new_group)
252
+ for group in groups:
253
+ seen_nodes.update(group)
254
+ self._diff_nodes.extend(analyze_diff_in_group([self._get_node_by_id(n_id) for n_id in group]))
255
+ if self._diff_nodes:
256
+ # 找出所有layer和sub_layer最小的节点
257
+ min_layer_sublayer = min((x.layer, x.sub_layer) for x in self._diff_nodes)
258
+ self._diff_nodes = [
259
+ node
260
+ for node in self._diff_nodes
261
+ if (node.layer, node.sub_layer) == min_layer_sublayer
262
+ ]
263
+ return
264
+
265
+ def _get_node_by_id(self, node_id):
266
+ splits = node_id.split(Const.SEP, 1)
267
+ if len(splits) < 2 or not splits[0].isdigit():
268
+ logger.error(f'invalid node_id {node_id}')
269
+ raise RuntimeError(f'invalid node_id {node_id}')
270
+ rank = int(splits[0])
271
+ return self._rank_comm_nodes_dict.get(rank, {}).get(node_id)
272
+
273
+ def _gen_analyze_info(self):
274
+ if not os.path.exists(self._output_path):
275
+ make_dir(self._output_path)
276
+ file_name = f'diff_analyze_{time.time_ns()}.json'
277
+ result_file = os.path.join(self._output_path, file_name)
278
+ result_content = defaultdict(list)
279
+ for node in self._diff_nodes:
280
+ result_content[f'rank_{node.rank}'].append(node.gen_node_info(self._paths[node.rank]))
281
+ save_json(result_file, result_content, 2)
282
+ logger.info(f"The analyze result is saved in: {result_file}")
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.common.log import logger
19
+
20
+
21
class DataProcessor:
    """Select and run the distributed compare entry point of the given framework."""

    def __init__(self, data_frame):
        self.data_frame = data_frame
        self.process_func = self._select_compare_func(data_frame)

    @staticmethod
    def _select_compare_func(data_frame):
        """Return the distributed compare function for ``data_frame``.

        The framework package is imported only for the framework that is requested.

        :raises ValueError: if ``data_frame`` is neither the PyTorch nor the MindSpore constant
        """
        if data_frame == Const.PT_FRAMEWORK:
            from msprobe.pytorch.compare.distributed_compare import compare_distributed
            return compare_distributed
        if data_frame == Const.MS_FRAMEWORK:
            from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed
            return ms_compare_distributed
        raise ValueError(f"Unsupported data_frame: {data_frame}")

    def process(self, npu_path, bench_path, output_path):
        """Run the selected compare with ``first_diff_analyze=True`` and return its result."""
        logger.info("Start comparing data ......")
        compare = self.process_func
        return compare(npu_path, bench_path, output_path, first_diff_analyze=True)
@@ -0,0 +1,188 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from msprobe.core.common.const import Const
18
+ from msprobe.core.common.log import logger
19
+ from msprobe.core.common.const import CompareConst
20
+ from msprobe.core.compare.find_first.utils import RankPath, DiffAnalyseConst
21
+
22
+
23
@dataclass
class DataNode:
    """Comparison rows of one operator on one rank.

    The rows are split into ``inputs`` and ``outputs``; each maps the row's
    NPU name to a dict of the NPU statistics found in that row.
    """
    op_name: str
    rank: int
    inputs: dict
    outputs: dict
    op_data: list
    layer: int = 0  # kept consistent with the layer of the communication node
    sub_layer: int = 0  # call order; a smaller value means an earlier call

    def __init__(self, op_name, rank, op_data, **kwargs):
        """Build the node and parse ``op_data`` immediately.

        Args:
            op_name: operator name.
            rank: rank the data belongs to.
            op_data: dict read through the ``"is_same"`` and ``"op_items"`` keys.
            **kwargs: only ``sub_layer`` is read (default ``0``).
        """
        self.op_name = op_name
        self.rank = rank
        self.stack = None
        self.inputs = {}
        self.outputs = {}
        self.is_diff = False
        self.parse_data(op_data)
        self.sub_layer = kwargs.get('sub_layer', 0)

    def find_stack(self):
        """Return the second element of the first stack entry naming this op.

        Returns an empty dict when no stack was recorded (``self.stack`` stays
        ``None`` until a row provides one) or when no entry matches.
        """
        if not self.stack:
            return {}
        for item in self.stack:
            if len(item) >= 2 and self.op_name in item[0]:
                return item[1]
        return {}

    def parse_data(self, op_data):
        """Fill ``is_diff``, ``op_data``, ``stack``, ``inputs`` and ``outputs``.

        A missing ``"op_items"`` entry is treated as an empty list of rows
        instead of failing while iterating ``None``.
        """
        self.is_diff = not op_data.get("is_same", True)
        op_items = op_data.get("op_items")  # list of comparison rows
        self.op_data = [] if op_items is None else op_items
        for cmp_data in self.op_data:
            name = cmp_data.get(CompareConst.NPU_NAME)
            # Build the metric dict for this row.
            metrics = {}

            if CompareConst.NPU_MAX in cmp_data:
                metrics = {
                    key: cmp_data.get(key)
                    for key in (CompareConst.NPU_MAX, CompareConst.NPU_MIN,
                                CompareConst.NPU_MEAN, CompareConst.NPU_NORM)
                }
            elif CompareConst.NPU_MD5 in cmp_data:
                metrics[CompareConst.NPU_MD5] = cmp_data.get(CompareConst.NPU_MD5)

            if CompareConst.NPU_P2POP_PEER in cmp_data:
                metrics[CompareConst.NPU_P2POP_PEER] = cmp_data.get(CompareConst.NPU_P2POP_PEER)

            # Keep the first stack value that is not the N/A marker.
            if cmp_data.get(CompareConst.STACK) != CompareConst.N_A and not self.stack:
                self.stack = cmp_data.get(CompareConst.STACK)

            state = cmp_data.get('state')
            if state == "input":
                self.inputs[name] = metrics
            elif state == "output":
                self.outputs[name] = metrics

    def gen_node_info(self, path: "RankPath"):
        """Return a dict with the op name, input/output metrics and stack.

        ``path`` is accepted for interface compatibility; it is not read here.
        """
        data_info_list = {Const.INPUT: self.inputs, Const.OUTPUT: self.outputs}
        return {'op_name': self.op_name,
                'data_info': data_info_list,
                'stack_info': self.stack}
81
+
82
+
83
class CommunicationNode:
    """Graph node wrapping the ``DataNode`` of one communication operator.

    Nodes reference each other through ``pre_node``/``next_nodes`` and through
    the ``link_nodes``, ``dst_nodes`` and ``src_nodes`` dicts keyed by node id.
    """

    def __init__(self, node_id, rank, data: "DataNode", layer=0, **kwargs):
        """Create the node and derive ``api``, ``call_cnt`` and ``type``.

        Args:
            node_id: key used for this node in other nodes' dicts.
            rank: rank of this node.
            data: the wrapped data node; its ``op_name`` is split on ``Const.SEP``.
            layer: initial layer value.
            **kwargs: optional ``pre_node``, ``link_nodes``, ``dst_nodes``,
                ``src_nodes``, ``next_nodes`` and ``compute_ops``.

        Raises:
            RuntimeError: if ``op_name`` splits into fewer than four parts.
        """
        self.node_id = node_id
        self.rank = rank
        self.data = data
        self.is_diff = data.is_diff
        self.layer = layer
        op_name_split = self.data.op_name.split(Const.SEP)
        if len(op_name_split) < 4:
            logger.error(f'invalid op_name: {self.data.op_name}')
            raise RuntimeError(f'invalid op_name: {self.data.op_name}')
        self.api = op_name_split[1]
        self.call_cnt = op_name_split[2]
        self.pre_node = kwargs.get('pre_node')
        self.link_nodes = kwargs.get('link_nodes', {})
        self.dst_nodes = kwargs.get('dst_nodes', {})
        self.src_nodes = kwargs.get('src_nodes', {})
        self.next_nodes = kwargs.get('next_nodes', {})
        self.compute_ops = kwargs.get('compute_ops', [])
        self.type = self._resolve_type()
        self.connected = False

    def add_next(self, node):
        """Register ``node`` as a successor and place it one layer below."""
        self.next_nodes[node.node_id] = node
        node.pre_node = self
        node.layer = self.layer + 1
        node.data.layer = node.layer

    def add_link(self, node):
        """Link ``node`` in both directions on the same layer; mark both connected."""
        self.link_nodes[node.node_id] = node
        node.link_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def add_dst(self, node):
        """Record ``node`` as a destination (and self as its source) on the same layer."""
        self.dst_nodes[node.node_id] = node
        node.src_nodes[self.node_id] = self
        node.layer = self.layer
        node.data.layer = node.layer
        self.connected = True
        node.connected = True

    def delete(self):
        """Remove the references other nodes hold to this node.

        Missing back-references are tolerated (``pop`` with a default), so a
        one-sided link does not raise ``KeyError``.
        """
        for node in self.next_nodes.values():
            node.pre_node = None
        for node in self.dst_nodes.values():
            if node.src_nodes:
                node.src_nodes.pop(self.node_id, None)
        for node in self.src_nodes.values():
            if node.dst_nodes:
                node.dst_nodes.pop(self.node_id, None)
        for node in self.link_nodes.values():
            if node.link_nodes:
                node.link_nodes.pop(self.node_id, None)
        if self.pre_node and self.pre_node.next_nodes:
            self.pre_node.next_nodes.pop(self.node_id, None)

    def find_connected_nodes(self):
        """Describe the nodes this node should be connected to.

        Returns:
            dict with ``'ranks'`` (set collected from the inputs), ``'api'``
            (``'Distributed.'`` plus the mapped api name) and ``'type'`` (the
            opposite direction of this node's type, defaulting to LINK).
        """
        tar_api = DiffAnalyseConst.P2P_API_MAPPING.get(self.api, self.api)
        ranks = set()
        # Keywords marking an input name as DST/SRC related.
        target_types = [DiffAnalyseConst.DST, DiffAnalyseConst.DST_GROUP,
                        DiffAnalyseConst.SRC, DiffAnalyseConst.SRC_GROUP]
        for input_name, v in self.data.inputs.items():
            if any(keyword in input_name for keyword in target_types):
                # Prefer the MD5 value; fall back to NPU_MAX when it is absent.
                if CompareConst.NPU_MD5 in v:
                    rank_val = int(v.get(CompareConst.NPU_MD5, 0))
                else:
                    rank_val = int(v.get(CompareConst.NPU_MAX, 0))
                # NOTE(review): a value of 0 is skipped here, so this branch
                # never adds rank 0 - confirm that this is intended.
                if rank_val:
                    ranks.add(rank_val)
            elif input_name.endswith('.group'):
                # Prefer the MD5 value; fall back to NPU_MAX when it is absent.
                val = v.get(CompareConst.NPU_MD5) if CompareConst.NPU_MD5 in v else v.get(CompareConst.NPU_MAX)
                # Only a string of the form "[a, b, ...]" is parsed; blank
                # parts (e.g. from "[]") are ignored instead of failing int().
                if isinstance(val, str) and val.startswith('[') and val.endswith(']'):
                    group_ranks = [int(part) for part in val.strip('[]').split(',') if part.strip()]
                    ranks.update(group_ranks)
            else:
                peer = v.get(CompareConst.NPU_P2POP_PEER)
                # A missing peer entry must not add None to the rank set.
                if peer is not None and peer != "None":
                    ranks.add(peer)

        return {'ranks': ranks, 'api': f'Distributed.{tar_api}',
                'type': DiffAnalyseConst.OPPOSITE_DIR.get(self.type, DiffAnalyseConst.LINK)}

    def _resolve_type(self):
        """Classify the node as SRC, DST or LINK from its SRC/DST related inputs.

        SRC inputs are checked before DST inputs; the first matching input
        decides the result. Without any match the node is a LINK.
        """
        for prefix, node_type in [(DiffAnalyseConst.SRC, DiffAnalyseConst.SRC),
                                  (DiffAnalyseConst.DST, DiffAnalyseConst.DST)]:
            for k, v in self.data.inputs.items():
                if prefix in k or f"{prefix}_GROUP" in k:
                    # Prefer the MD5 value; fall back to NPU_MAX when it is absent.
                    compare_val = v.get(CompareConst.NPU_MD5) if CompareConst.NPU_MD5 in v \
                        else v.get(CompareConst.NPU_MAX)
                    # NOTE(review): the raw value is compared with self.rank
                    # without the int() conversion used in
                    # find_connected_nodes - confirm the stored value type.
                    return node_type if compare_val == self.rank \
                        else DiffAnalyseConst.OPPOSITE_DIR[node_type]
        return DiffAnalyseConst.LINK