mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -0,0 +1,189 @@
1
+ # Copyright (c) 2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import OrderedDict
17
+ from dataclasses import dataclass
18
+ import sys
19
+ import time
20
+ import psutil
21
+
22
+ from msprobe.core.common.file_utils import check_file_or_directory_path, load_json
23
+ from msprobe.core.common.const import Const
24
+
25
+
26
@dataclass
class RankPath:
    """Associate a rank with the dump path that belongs to it.

    The constructor signature is ``RankPath(rank, dump_path)``.

    Attributes:
        rank: rank identifier supplied by the caller.
        dump_path: path that is passed to ``check_file_or_directory_path``
            while the instance is being constructed.

    Raises:
        Whatever ``check_file_or_directory_path`` raises for a rejected path.
    """
    rank: int
    dump_path: str

    def __post_init__(self):
        # Validation lives in __post_init__ so that the dataclass-generated
        # __init__ is actually used; a hand-written __init__ would shadow it
        # and force the field list to be maintained in two places.
        check_file_or_directory_path(self.dump_path)
35
+
36
+
37
class FileCache:
    """Singleton cache that loads JSON files on first use and keeps them in memory.

    Every ``FileCache()`` call returns the same object. Cached entries are
    tracked with an access counter, a last-access timestamp and an estimated
    size so that entries can be evicted once the estimated total exceeds a
    quarter of the memory that was available when the cache was first created.
    """
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Extra arguments are accepted for signature compatibility only:
        # object.__new__ rejects them, and __init__ takes none anyway.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python calls __init__ on every FileCache() invocation, even though
        # __new__ hands back the shared instance. Without this guard each call
        # would throw away everything cached so far.
        if getattr(self, "_initialized", False):
            return
        # Use at most 1/4 of the memory that is currently available.
        self._max_memory_usage = psutil.virtual_memory().available / 4
        self._cache = OrderedDict()
        self._access_cnt = {}
        self._access_time = {}
        self._size = {}
        self._initialized = True

    @staticmethod
    def _sizeof(obj):
        """Estimate the memory footprint of ``obj`` in bytes.

        Walks dicts (keys and values), lists, tuples, sets and frozensets
        iteratively and sums ``sys.getsizeof`` of every distinct object, so a
        shared or self-referencing object is counted once.
        """
        seen = set()
        pending = [obj]
        size = 0
        while pending:
            current = pending.pop()
            current_id = id(current)
            if current_id in seen:
                continue
            seen.add(current_id)
            size += sys.getsizeof(current)
            if isinstance(current, dict):
                pending.extend(current.keys())
                pending.extend(current.values())
            elif isinstance(current, (list, tuple, set, frozenset)):
                pending.extend(current)
        return size

    def load_json(self, json_path):
        """Return the parsed content of ``json_path``, reading it only on a cache miss.

        A hit updates the entry's access counter and timestamp and marks it as
        most recently used. A miss first evicts entries if the cache is over
        budget and then loads and stores the file.
        """
        if json_path in self._cache:
            self._access_cnt[json_path] += 1
            self._access_time[json_path] = time.monotonic()
            self._cache.move_to_end(json_path)
            return self._cache[json_path]
        self._cleanup()
        return self._load(json_path)

    def _load(self, json_path):
        """Read ``json_path`` with the module-level ``load_json`` and cache the result."""
        data = load_json(json_path)
        self._add_to_cache(json_path, data)
        return data

    def _add_to_cache(self, key, data):
        """Store ``data`` under ``key``; an existing key is only moved to the end."""
        if key in self._cache:
            self._cache.move_to_end(key)
        else:
            self._cache[key] = data
            self._access_cnt[key] = 0
            self._access_time[key] = time.monotonic()
            self._size[key] = self._sizeof(data)

    def _calc_cache_size(self):
        """Return the estimated total size: the container itself plus all entries."""
        return sys.getsizeof(self._cache) + sum(self._size.values())

    def _cleanup(self):
        """Evict entries until the estimated size fits the budget or the cache is empty.

        The victim is picked among three candidates (least frequently used,
        least recently used, largest), preferring the lowest access count,
        then the oldest access time, then the largest size.
        """
        while self._calc_cache_size() > self._max_memory_usage and self._cache:
            least_frequent_key = min(self._access_cnt.keys(), key=lambda k: self._access_cnt[k])
            least_recent_key = min(self._access_time.keys(), key=lambda k: self._access_time[k])
            largest_key = max(self._cache.keys(), key=lambda k: self._size[k])
            key_to_rm = min([least_frequent_key, least_recent_key, largest_key],
                            key=lambda k: (self._access_cnt[k], self._access_time[k], -self._size[k]))
            del self._cache[key_to_rm]
            del self._access_cnt[key_to_rm]
            del self._access_time[key_to_rm]
            del self._size[key_to_rm]
111
+
112
+
113
def is_communication_op(op_name):
    """Return True when ``op_name`` looks like a communication operator.

    The name must start with one of the distributed API prefixes
    (``Const.DISTRIBUTED``, ``Const.MINT_DIST_API_TYPE_PREFIX``,
    ``Const.MS_API_TYPE_COM``) and contain at least one entry of
    ``DiffAnalyseConst.COMMUNICATION_KEYWORDS``.
    """
    # The keywords cover communication operations such as all_reduce, send and
    # broadcast. They are meant to be read from the wrap file; for now they are
    # hard-coded in this file.
    comm_prefixes = (Const.DISTRIBUTED, Const.MINT_DIST_API_TYPE_PREFIX, Const.MS_API_TYPE_COM)
    if not op_name.startswith(comm_prefixes):
        return False
    for keyword in DiffAnalyseConst.COMMUNICATION_KEYWORDS:
        if keyword in op_name:
            return True
    return False
118
+
119
+
120
def is_ignore_op(op_name):
    """Return True when ``op_name`` contains one of the ignored operator names.

    Matching is by substring, so any name that contains ``Torch.empty``,
    ``Torch.fill`` or ``Tensor.__setitem__`` is treated as ignorable.
    """
    for ignored in ('Torch.empty', 'Torch.fill', 'Tensor.__setitem__'):
        if ignored in op_name:
            return True
    return False
127
+
128
+
129
class DiffAnalyseConst:
    """Constants shared by the diff analysis helpers in this module."""

    # Substrings that identify a communication operator by name.
    COMMUNICATION_KEYWORDS = {
        # point-to-point
        'send', 'recv', 'isend', 'irecv',
        'send_object_list', 'recv_object_list',
        # collectives
        'broadcast', 'all_reduce', 'reduce',
        'all_gather', 'gather', 'scatter', 'reduce_scatter',
        '_reduce_scatter_base', '_all_gather_base',
        'all_to_all_single', 'all_to_all',
        'all_gather_into_tensor', 'reduce_scatter_tensor',
    }

    # Each point-to-point API mapped to its counterpart.
    P2P_API_MAPPING = {
        'send': 'recv',
        'recv': 'send',
        'isend': 'irecv',
        'irecv': 'isend',
        'send_object_list': 'recv_object_list',
        'recv_object_list': 'send_object_list',
    }

    SRC = 'src'
    DST = 'dst'
    SRC_GROUP = 'group_src'
    DST_GROUP = 'group_dst'
    LINK = 'link'

    # Directed APIs mapped to either SRC or DST.
    DIRECTED_API = {
        'send': DST,
        'recv': SRC,
        'isend': DST,
        'irecv': SRC,
        'broadcast': SRC,
        'scatter': SRC,
        'gather': DST,
        'send_object_list': DST,
        'recv_object_list': SRC,
    }
    OPPOSITE_DIR = {SRC: DST, DST: SRC}

    # File names used for dump data.
    DUMP_FILE = "dump.json"
    CONSTRUCT_FILE = "construct.json"
    STACK_FILE = "stack.json"
164
+
165
+
166
def analyze_diff_in_group(nodes_group):
    """Collect the nodes of a communication group that should be reported as diffs.

    The result contains, in order: every compute op attached to every node of
    ``nodes_group``, followed by the ``data`` of every node whose
    ``data.is_diff`` is truthy. Each collected object gets its ``layer``
    attribute set to the layer of the communication node it came from.
    """
    collected = []

    # First check whether the input of any src or link node is abnormal.
    # NOTE(review): these two selections are not consumed further down; they
    # are kept because they still read ``type`` and ``is_diff`` on the nodes.
    src_or_link_nodes = [
        member for member in nodes_group
        if member.type in [DiffAnalyseConst.SRC, DiffAnalyseConst.LINK]
    ]
    abnormal_input_nodes = [member for member in src_or_link_nodes if member.is_diff]

    # If something is abnormal, trace back through the compute nodes to find
    # its origin. Re-running the nodes on CPU to verify their results would
    # need every compute node to be recorded/mapped and is not implemented yet.
    for comm_node in nodes_group:
        for compute_op in comm_node.compute_ops:
            compute_op.layer = comm_node.layer
            collected.append(compute_op)

    # Pick the communication nodes whose output is abnormal.
    abnormal_output_nodes = [member for member in nodes_group if member.data.is_diff]
    for comm_node in abnormal_output_nodes:
        comm_node.data.layer = comm_node.layer
        collected.append(comm_node.data)
    return collected
@@ -16,10 +16,8 @@
16
16
  import abc
17
17
  import math
18
18
  import multiprocessing
19
- import re
20
19
  from collections import namedtuple
21
20
 
22
- import numpy as np
23
21
  import openpyxl
24
22
  from openpyxl.styles import PatternFill
25
23
  from openpyxl.utils.dataframe import dataframe_to_rows
@@ -28,8 +26,8 @@ from tqdm import tqdm
28
26
  from msprobe.core.common.const import CompareConst, Const
29
27
  from msprobe.core.common.file_utils import save_workbook
30
28
  from msprobe.core.common.log import logger
31
- from msprobe.core.common.utils import get_header_index, safe_get_value
32
- from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
29
+ from msprobe.core.common.utils import get_header_index, CompareException
30
+ from msprobe.core.compare.utils import table_value_is_valid, gen_api_batches
33
31
  from msprobe.core.compare.config import ModeConfig
34
32
 
35
33
 
@@ -54,10 +52,12 @@ class CheckOrderMagnitude(HighlightCheck):
54
52
  api_in, api_out, num = info
55
53
  max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
56
54
  else CompareConst.MAX_ABS_ERR, dump_mode)
57
- if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
55
+ max_diff_in = abs(api_in[max_diff_index])
56
+ max_diff_out = abs(api_out[max_diff_index])
57
+ if max_diff_in > max_diff_out or (max_diff_in <= 1 or max_diff_out <= 1):
58
58
  return
59
- in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
60
- out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
59
+ in_order = 0 if max_diff_in < 1 else math.log10(max_diff_in)
60
+ out_order = 0 if max_diff_out < 1 else math.log10(max_diff_out)
61
61
  if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
62
62
  add_highlight_row_info(color_columns.yellow, num,
63
63
  "maximum absolute error of both input/parameters and output exceed 1, "
@@ -102,20 +102,28 @@ class CheckMaxRelativeDiff(HighlightCheck):
102
102
  """检查最大相对差异"""
103
103
 
104
104
  def apply(self, info, color_columns, dump_mode):
105
+ def get_number(data):
106
+ """统计量相对值如果为正常百分数据,str格式并以%结尾"""
107
+ if isinstance(data, str) and data.endswith("%"):
108
+ return float(data[:-1]) / 100
109
+ return data
110
+
105
111
  api_in, api_out, num = info
106
- max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
107
- bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
108
- input_max_relative_diff = np.abs(
109
- np.divide(api_in[max_diff_index], max(Const.FLOAT_EPSILON, api_in[bench_max_index])))
110
- output_max_relative_diff = np.abs(
111
- np.divide(api_out[max_diff_index], max(Const.FLOAT_EPSILON, api_out[bench_max_index])))
112
- if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
113
- (float, int)):
112
+ max_rel_diff = get_header_index(CompareConst.MAX_RELATIVE_ERR, dump_mode)
113
+ input_max_relative_diff = api_in[max_rel_diff] # 内部数据,长度总是和表头一致,不会越界
114
+ output_max_relative_diff = api_out[max_rel_diff]
115
+ input_max_relative_diff = get_number(input_max_relative_diff)
116
+ output_max_relative_diff = get_number(output_max_relative_diff)
117
+
118
+ if not isinstance(output_max_relative_diff, (float, int)):
114
119
  return
115
120
  if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
116
121
  add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
117
- elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
118
- input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
122
+
123
+ if not isinstance(input_max_relative_diff, (float, int)):
124
+ return
125
+ if (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
126
+ input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
119
127
  add_highlight_row_info(color_columns.yellow, num,
120
128
  "The output's maximum relative error exceeds 0.1, "
121
129
  "while the input/parameter's is below 0.01")
@@ -139,12 +147,25 @@ class CheckOverflow(HighlightCheck):
139
147
  add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
140
148
 
141
149
 
150
class CheckReqGradConsist(HighlightCheck):
    """Check whether requires_grad is consistent."""

    def apply(self, info, color_columns, dump_mode):
        """Record a yellow highlight when the row's requires_grad-consistency cell is falsy.

        ``info`` unpacks into the row data and the number used to identify the row.
        """
        row, row_number = info
        consist_column = get_header_index(CompareConst.REQ_GRAD_CONSIST, dump_mode)
        if row[consist_column]:
            return
        add_highlight_row_info(color_columns.yellow, row_number, "requires_grad is inconsistent")
158
+
159
+
142
160
  class HighlightRules:
143
161
  """高亮规则集合,用于检查API的误差"""
144
162
  # 适用于每行的规则
145
163
  basic_rules = {
146
164
  "check_overflow": CheckOverflow()
147
165
  }
166
+ consist_rules = {
167
+ "check_req_grad_consist": CheckReqGradConsist()
168
+ }
148
169
 
149
170
  # 用于比较输入和输出的规则
150
171
  # 真实数据检查规则
@@ -160,64 +181,10 @@ class HighlightRules:
160
181
  }
161
182
 
162
183
 
163
- class ApiBatch:
164
- def __init__(self, api_name: str, start: int):
165
- self.api_name = api_name
166
- self.start = start
167
- self.input_len = 1 # input的数量
168
- self.params_end_index = start + 1 # params的结束index
169
- self.output_end_index = start + 1 # output的结束index
170
- self.params_grad_end_index = start + 1 # params_grad的结束index
171
- # 内部state的标志("input", "output", "parameters", "parameters_grad"),
172
- # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index
173
- self._state = Const.INPUT # api_batch初始化为input
174
-
175
- def set_state(self, state: str):
176
- """设置当前状态"""
177
- if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
178
- self._state = state
179
- else:
180
- raise ValueError(f"Invalid state: {state}")
181
-
182
- def increment(self, state: str):
183
- self.set_state(state)
184
- if self._state == Const.INPUT or self._state == Const.KWARGS:
185
- self.input_len += 1
186
- self.params_end_index += 1
187
- self.output_end_index += 1
188
- if self._state == Const.PARAMS:
189
- self.params_end_index += 1
190
- self.output_end_index += 1
191
- if self._state == Const.OUTPUT:
192
- self.output_end_index += 1
193
- self.params_grad_end_index += 1
194
-
195
-
196
184
  class HighLight:
197
- def __init__(self, mode_config: ModeConfig):
185
+ def __init__(self, mode_config: ModeConfig, rank):
198
186
  self.mode_config = mode_config
199
-
200
- @staticmethod
201
- def api_batches_update(api_batches, api_name, state, index):
202
- """
203
- 当一个api的所有item更新完后,input, output的索引范围:
204
- input: [start: start+input_len]
205
- output: [start+input_len: output_end_index]
206
- params: [output_end_index: params_end_index]
207
- """
208
- if not api_batches:
209
- api_batches.append(ApiBatch(api_name, index))
210
- else:
211
- api_batch = api_batches[-1]
212
- if api_batch.api_name == api_name or (
213
- not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
214
- try:
215
- api_batch.increment(state)
216
- except ValueError as e:
217
- logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
218
- raise CompareException(CompareException.INVALID_STATE_ERROR) from e
219
- else:
220
- api_batches.append(ApiBatch(api_name, index))
187
+ self.rank = rank
221
188
 
222
189
  @staticmethod
223
190
  def check_indices_numeric(api_items, indices: list):
@@ -232,7 +199,7 @@ class HighLight:
232
199
  if CompareConst.NPU_MD5 in result_df.columns:
233
200
  return
234
201
 
235
- err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
202
+ err_msg = result_df.get(CompareConst.ERROR_MESSAGE).copy()
236
203
  red_lines_num_set = highlight_dict.get('red_rows')
237
204
 
238
205
  for color in ['red', 'yellow']:
@@ -273,12 +240,11 @@ class HighLight:
273
240
  def find_compare_result_error_rows(self, result_df, highlight_dict):
274
241
  """将dataframe根据API分组,并找到有误差的算子用于高亮"""
275
242
  result = result_df.values
276
- api_batches = []
277
- for i, res_i in enumerate(result):
278
- api_full_name = safe_get_value(res_i, 0, "res_i")
279
- api_name, state = get_name_and_state(api_full_name)
280
- self.api_batches_update(api_batches, api_name, state, i)
281
- with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
243
+ header = result_df.columns.tolist()
244
+ api_batches = gen_api_batches(result, header)
245
+ default_bar_desc = 'API/Module Analyse Progress'
246
+ bar_desc_add_rank = f'[{self.rank}]' + default_bar_desc if self.rank else default_bar_desc
247
+ with tqdm(total=len(api_batches), desc=bar_desc_add_rank, unit="item", ncols=100) as progress_bar:
282
248
  for api_batch in api_batches:
283
249
  self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
284
250
  highlight_dict)
@@ -328,6 +294,13 @@ class HighLight:
328
294
  api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
329
295
  self.apply_comparison_rules(api_info, color_columns)
330
296
 
297
+ # 对单行API的输入或输出进行requires_grad是否一致判断
298
+ for i, line in enumerate(result):
299
+ index = api_batch_start + i
300
+ line_info = LineInfo(line_data=line, num_pointer=index)
301
+ for rule in HighlightRules.consist_rules.values():
302
+ rule.apply(line_info, color_columns, self.mode_config.dump_mode)
303
+
331
304
  red_lines_num_set = {x[0] for x in red_lines}
332
305
  yellow_lines_num_set = {x[0] for x in yellow_lines}
333
306
  highlight_dict.get('red_rows', set()).update(red_lines_num_set)
@@ -349,28 +322,19 @@ class HighLight:
349
322
 
350
323
  self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
351
324
 
325
+ self.df_malicious_value_check(result_df)
326
+
352
327
  wb = openpyxl.Workbook()
353
328
  ws = wb.active
354
-
355
- # write header
356
- logger.info('Initializing Excel file.')
357
-
358
- self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
359
-
360
329
  result_df_convert = result_df.applymap(self.compare_result_df_convert)
361
-
362
330
  for row in dataframe_to_rows(result_df_convert, index=False, header=True):
363
331
  ws.append(row)
364
332
 
365
333
  # 对可疑数据标色
366
334
  logger.info('Coloring Excel in progress.')
335
+ red_fill = PatternFill(start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid")
336
+ yellow_fill = PatternFill(start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid")
367
337
  col_len = len(result_df.columns)
368
- red_fill = PatternFill(
369
- start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
370
- )
371
- yellow_fill = PatternFill(
372
- start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
373
- )
374
338
  for i in highlight_dict.get("red_rows", []):
375
339
  for j in range(1, col_len + 1):
376
340
  ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
@@ -378,7 +342,6 @@ class HighLight:
378
342
  for j in range(1, col_len + 1):
379
343
  ws.cell(row=i + 2, column=j).fill = yellow_fill
380
344
 
381
- logger.info('Saving Excel file to disk: %s' % file_path)
382
345
  save_workbook(wb, file_path)
383
346
 
384
347
  def handle_multi_process_malicious_value_check(self, func, result_df):
@@ -396,22 +359,32 @@ class HighLight:
396
359
 
397
360
  def err_call(args):
398
361
  logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
399
- try:
400
- pool.close()
401
- except OSError:
402
- logger.error("Pool terminate failed")
403
362
 
404
363
  result_df_columns = result_df.columns.tolist()
405
364
  for column in result_df_columns:
406
365
  self.value_check(column)
366
+ async_results = []
407
367
  for df_chunk in chunks:
408
- pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
368
+ result = pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
369
+ async_results.append(result)
409
370
 
410
371
  pool.close()
372
+
373
+ for ar in async_results:
374
+ try:
375
+ ar.get(timeout=3600)
376
+ except Exception as e:
377
+ logger.error(f"Task failed with exception: {e}")
378
+ pool.terminate()
379
+ raise CompareException(CompareException.MULTIPROCESS_ERROR) from e
380
+
411
381
  pool.join()
412
382
 
413
- def df_malicious_value_check(self, df_chunk, result_df_columns):
414
- for row in df_chunk.itertuples(index=False):
383
+ def df_malicious_value_check(self, result_df):
384
+ result_df_columns = result_df.columns.tolist()
385
+ for column in result_df_columns:
386
+ self.value_check(column)
387
+ for row in result_df.itertuples(index=False):
415
388
  api_name = row[0]
416
389
  for i, value in enumerate(row):
417
390
  self.value_check(value, api_name, i, result_df_columns)
@@ -18,12 +18,12 @@ from collections import defaultdict
18
18
 
19
19
  from msprobe.core.common.const import CompareConst, Const
20
20
  from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
21
- from msprobe.core.common.utils import (add_time_with_yaml,
22
- detect_framework_by_dump_json,
23
- get_stack_construct_by_dump_json_path)
21
+ from msprobe.core.common.utils import add_time_with_yaml, detect_framework_by_dump_json, \
22
+ get_stack_construct_by_dump_json_path, CompareException
24
23
  from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
25
24
  from msprobe.core.compare.utils import read_op, reorder_op_name_list
26
25
  from msprobe.core.common.decorator import recursion_depth_decorator
26
+ from msprobe.core.common.log import logger
27
27
 
28
28
 
29
29
  class LayerTrie:
@@ -63,7 +63,11 @@ class LayerTrie:
63
63
  node = node.children[name]
64
64
  if index >= len(node.data_items[state]):
65
65
  return default_value
66
- return node.data_items[state][index]
66
+ if node.data_items[state]:
67
+ return node.data_items[state][index]
68
+ else:
69
+ logger.error(f"node.data_items of state:{state} is empty, please check.")
70
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR)
67
71
 
68
72
  def save_to_yaml(self, output_path):
69
73
  result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
@@ -208,7 +212,8 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa
208
212
  def read_full_op_names(data, op_name):
209
213
  op_parsed_list = read_op(data.get(op_name, {}), op_name)
210
214
  full_op_names = [op_parsed.get('full_op_name') for op_parsed in op_parsed_list]
211
- return full_op_names
215
+ states = [op_parsed.get(Const.STATE) for op_parsed in op_parsed_list]
216
+ return full_op_names, states
212
217
 
213
218
  def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names):
214
219
  suffix_to_full_op_name = {}
@@ -228,10 +233,10 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa
228
233
  for npu_op_name, bench_op_name in api_mapping.items():
229
234
  if not npu_op_name:
230
235
  continue
231
- npu_full_op_names = read_full_op_names(npu_data, npu_op_name)
232
- bench_full_op_names = read_full_op_names(bench_data, bench_op_name)
233
- npu_full_op_names_reorder = reorder_op_name_list(npu_full_op_names)
234
- bench_full_op_names_reorder = reorder_op_name_list(bench_full_op_names)
236
+ npu_full_op_names, npu_states = read_full_op_names(npu_data, npu_op_name)
237
+ bench_full_op_names, bench_states = read_full_op_names(bench_data, bench_op_name)
238
+ npu_full_op_names_reorder, _ = reorder_op_name_list(npu_full_op_names, npu_states)
239
+ bench_full_op_names_reorder, _ = reorder_op_name_list(bench_full_op_names, bench_states)
235
240
  mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names_reorder,
236
241
  bench_op_name, bench_full_op_names_reorder)
237
242
  data_mapping.update(mapping)
@@ -109,8 +109,8 @@ def check_index_dump_mode_consistent(dump_mode, rank_num):
109
109
  return []
110
110
 
111
111
  dump_mode_compare_index_map = {
112
- Const.ALL: CompareConst.ALL_COMPARE_INDEX,
113
- Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX
112
+ Const.ALL: CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST],
113
+ Const.SUMMARY: CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
114
114
  }
115
115
  valid_compare_index = dump_mode_compare_index_map.get(dump_mode)
116
116