mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (177)
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
  3. msprobe/README.md +27 -22
  4. msprobe/core/common/const.py +129 -60
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/inplace_ops.yaml +1 -0
  9. msprobe/core/common/utils.py +43 -33
  10. msprobe/core/compare/acc_compare.py +43 -74
  11. msprobe/core/compare/check.py +2 -6
  12. msprobe/core/compare/highlight.py +2 -0
  13. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  14. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  15. msprobe/core/compare/merge_result/merge_result.py +16 -9
  16. msprobe/core/compare/merge_result/utils.py +81 -0
  17. msprobe/core/compare/multiprocessing_compute.py +19 -12
  18. msprobe/core/compare/npy_compare.py +30 -12
  19. msprobe/core/compare/utils.py +30 -10
  20. msprobe/core/data_dump/api_registry.py +176 -0
  21. msprobe/core/data_dump/data_collector.py +58 -13
  22. msprobe/core/data_dump/data_processor/base.py +94 -10
  23. msprobe/core/data_dump/data_processor/factory.py +3 -0
  24. msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
  25. msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
  26. msprobe/core/data_dump/json_writer.py +61 -40
  27. msprobe/core/grad_probe/constant.py +1 -0
  28. msprobe/core/grad_probe/grad_compare.py +1 -1
  29. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  30. msprobe/docs/01.installation.md +27 -1
  31. msprobe/docs/02.config_introduction.md +27 -23
  32. msprobe/docs/03.config_examples.md +24 -0
  33. msprobe/docs/05.data_dump_PyTorch.md +103 -16
  34. msprobe/docs/06.data_dump_MindSpore.md +76 -32
  35. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  36. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  37. msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
  38. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  39. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  40. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  41. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  42. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  43. msprobe/docs/18.online_dispatch.md +1 -1
  44. msprobe/docs/19.monitor.md +332 -273
  45. msprobe/docs/21.visualization_PyTorch.md +42 -13
  46. msprobe/docs/22.visualization_MindSpore.md +43 -13
  47. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  48. msprobe/docs/27.dump_json_instruction.md +301 -27
  49. msprobe/docs/28.debugger_save_instruction.md +94 -0
  50. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  51. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  52. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  53. msprobe/docs/FAQ.md +3 -11
  54. msprobe/docs/img/compare_result.png +0 -0
  55. msprobe/docs/img/merge_result.png +0 -0
  56. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  57. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  58. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  59. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  60. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  61. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  63. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  64. msprobe/mindspore/__init__.py +4 -2
  65. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
  66. msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
  67. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  68. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  69. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  70. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  71. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  72. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  73. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
  74. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  75. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  76. msprobe/mindspore/common/const.py +61 -0
  77. msprobe/mindspore/common/utils.py +48 -18
  78. msprobe/mindspore/compare/ms_compare.py +27 -19
  79. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  80. msprobe/mindspore/debugger/debugger_config.py +31 -6
  81. msprobe/mindspore/debugger/precision_debugger.py +45 -14
  82. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  83. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  84. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  85. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  86. msprobe/mindspore/dump/jit_dump.py +21 -15
  87. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  88. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  89. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  90. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  91. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  92. msprobe/mindspore/grad_probe/global_context.py +2 -0
  93. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  94. msprobe/mindspore/grad_probe/hook.py +2 -4
  95. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  96. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  97. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  98. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  99. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  100. msprobe/mindspore/monitor/features.py +63 -0
  101. msprobe/mindspore/monitor/module_hook.py +873 -0
  102. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  103. msprobe/mindspore/monitor/utils.py +309 -0
  104. msprobe/mindspore/ms_config.py +8 -2
  105. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  106. msprobe/mindspore/service.py +114 -34
  107. msprobe/pytorch/__init__.py +0 -1
  108. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  109. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
  110. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  111. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  112. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  116. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  117. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  118. msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
  119. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
  120. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  121. msprobe/pytorch/common/utils.py +97 -4
  122. msprobe/pytorch/debugger/debugger_config.py +19 -9
  123. msprobe/pytorch/debugger/precision_debugger.py +24 -1
  124. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  125. msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
  126. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  127. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  132. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  133. msprobe/pytorch/function_factory.py +8 -2
  134. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  135. msprobe/pytorch/hook_module/api_register.py +131 -0
  136. msprobe/pytorch/hook_module/hook_module.py +19 -14
  137. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  138. msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
  139. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  140. msprobe/pytorch/monitor/csv2tb.py +18 -14
  141. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  142. msprobe/pytorch/monitor/module_hook.py +238 -193
  143. msprobe/pytorch/monitor/module_metric.py +9 -6
  144. msprobe/pytorch/monitor/optimizer_collect.py +100 -67
  145. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  146. msprobe/pytorch/monitor/utils.py +76 -44
  147. msprobe/pytorch/online_dispatch/compare.py +0 -2
  148. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  149. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  150. msprobe/pytorch/online_dispatch/utils.py +3 -0
  151. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  152. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  153. msprobe/pytorch/pt_config.py +30 -29
  154. msprobe/pytorch/service.py +114 -32
  155. msprobe/visualization/builder/graph_builder.py +75 -10
  156. msprobe/visualization/builder/msprobe_adapter.py +7 -6
  157. msprobe/visualization/compare/graph_comparator.py +42 -38
  158. msprobe/visualization/compare/mode_adapter.py +0 -19
  159. msprobe/visualization/graph/base_node.py +11 -3
  160. msprobe/visualization/graph/distributed_analyzer.py +71 -3
  161. msprobe/visualization/graph/graph.py +0 -11
  162. msprobe/visualization/graph/node_op.py +4 -3
  163. msprobe/visualization/graph_service.py +4 -5
  164. msprobe/visualization/utils.py +12 -35
  165. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
  166. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  167. msprobe/pytorch/hook_module/api_registry.py +0 -166
  168. msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
  169. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  171. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  172. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  173. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  174. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  175. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  176. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  177. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -17,12 +17,14 @@ import re
17
17
  from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
18
18
  from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
19
19
  from msprobe.visualization.graph.graph import Graph, NodeOp
20
- from msprobe.visualization.graph.node_colors import NodeColors
21
20
  from msprobe.visualization.compare.mode_adapter import ModeAdapter
22
21
  from msprobe.core.common.const import Const
22
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
23
 
24
24
 
25
25
  class GraphComparator:
26
+ MAX_DEPTH = 1000
27
+
26
28
  def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
27
29
  self.graph_n = graphs[0]
28
30
  self.graph_b = graphs[1]
@@ -41,7 +43,7 @@ class GraphComparator:
41
43
  else:
42
44
  self._compare_nodes(self.graph_n.root)
43
45
  self._postcompare()
44
-
46
+
45
47
  def add_compare_result_to_node(self, node, compare_result_list):
46
48
  """
47
49
  将比对结果添加到节点的输入输出数据中
@@ -66,43 +68,8 @@ class GraphComparator:
66
68
  self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
67
69
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
68
70
  node.data.update(other_dict)
69
-
70
- def _parse_param(self, dump_path_param, output_path):
71
- self.dump_path_param = dump_path_param
72
- self.output_path = output_path
73
- compare_mode = get_compare_mode(self.dump_path_param)
74
- self.ma = ModeAdapter(compare_mode)
75
- self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
76
- self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
77
- self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
78
-
79
- def _postcompare(self):
80
- self._handle_api_collection_index()
81
- if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
82
- return
83
- df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
84
- df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
85
- compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
86
- for node in self.ma.compare_nodes:
87
- precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
88
- node.data[GraphConst.JSON_INDEX_KEY] = precision_index
89
-
90
- def _handle_api_collection_index(self):
91
- """
92
- api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
93
- md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
94
- """
95
- for node in self.graph_n.root.subnodes:
96
- if node.op == NodeOp.api_collection:
97
- precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
98
- else GraphConst.MIN_INDEX_KEY
99
- for api in node.subnodes:
100
- precision_index = min(precision_index,
101
- api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
102
- if self.ma.compare_mode == GraphConst.MD5_COMPARE \
103
- else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
104
- node.data[GraphConst.JSON_INDEX_KEY] = precision_index
105
71
 
72
+ @recursion_depth_decorator('GraphComparator._compare_nodes', max_depth=MAX_DEPTH)
106
73
  def _compare_nodes(self, node_n):
107
74
  """
108
75
  递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
@@ -126,6 +93,7 @@ class GraphComparator:
126
93
  for subnode in node_n.subnodes:
127
94
  self._compare_nodes(subnode)
128
95
 
96
+ @recursion_depth_decorator('GraphComparator._compare_nodes_fuzzy', max_depth=MAX_DEPTH)
129
97
  def _compare_nodes_fuzzy(self, node_n):
130
98
  if node_n.op != NodeOp.function_api:
131
99
  # 模块经过模糊匹配
@@ -146,6 +114,42 @@ class GraphComparator:
146
114
  for sub_node in node_n.subnodes:
147
115
  self._compare_nodes_fuzzy(sub_node)
148
116
 
117
+ def _parse_param(self, dump_path_param, output_path):
118
+ self.dump_path_param = dump_path_param
119
+ self.output_path = output_path
120
+ compare_mode = get_compare_mode(self.dump_path_param)
121
+ self.ma = ModeAdapter(compare_mode)
122
+ self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
123
+ self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
124
+ self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
125
+
126
+ def _postcompare(self):
127
+ self._handle_api_collection_index()
128
+ if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
129
+ return
130
+ df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
131
+ df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
132
+ compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
133
+ for node in self.ma.compare_nodes:
134
+ precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
135
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
136
+
137
+ def _handle_api_collection_index(self):
138
+ """
139
+ api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
140
+ md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
141
+ """
142
+ for node in self.graph_n.root.subnodes:
143
+ if node.op == NodeOp.api_collection:
144
+ precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
145
+ else GraphConst.MIN_INDEX_KEY
146
+ for api in node.subnodes:
147
+ precision_index = min(precision_index,
148
+ api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
149
+ if self.ma.compare_mode == GraphConst.MD5_COMPARE \
150
+ else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
151
+ node.data[GraphConst.JSON_INDEX_KEY] = precision_index
152
+
149
153
  def _get_and_add_result(self, node_n, node_b):
150
154
  compare_result_list = compare_node([node_n.id, node_b.id],
151
155
  [self.data_n_dict, self.data_b_dict],
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import json
17
- import math
18
17
  from msprobe.core.common.const import CompareConst, Const
19
18
  from msprobe.visualization.utils import ToolTip, GraphConst, str2float
20
19
 
@@ -157,24 +156,6 @@ class ModeAdapter:
157
156
  return
158
157
  self.csv_data.extend(compare_result_list)
159
158
 
160
- def add_error_key(self, node_data):
161
- """
162
- 根据不同的模式进行提供不同错误信息
163
- """
164
- for key, value in node_data.items():
165
- if not isinstance(value, dict):
166
- continue
167
- if self.compare_mode == GraphConst.SUMMARY_COMPARE:
168
- message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
169
- CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
170
- elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
171
- message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
172
- else:
173
- # 输出件优化
174
- message = []
175
- value[GraphConst.ERROR_KEY] = message
176
- node_data[key] = value
177
-
178
159
  def get_tool_tip(self):
179
160
  """
180
161
  用于前端展示字段的具体含义
@@ -12,10 +12,11 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+
15
16
  from msprobe.core.overflow_check.level import OverflowLevel
16
- from msprobe.visualization.graph.node_op import NodeOp
17
17
  from msprobe.visualization.utils import GraphConst
18
18
  from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
19
+ from msprobe.core.common.log import logger
19
20
 
20
21
 
21
22
  class BaseNode:
@@ -34,6 +35,7 @@ class BaseNode:
34
35
  self.micro_step_id = None
35
36
  self.overflow_level = None
36
37
  self.matched_distributed = {}
38
+ self.batch_p2p_info = []
37
39
 
38
40
  def __str__(self):
39
41
  info = f'id:\t{self.id}'
@@ -92,8 +94,8 @@ class BaseNode:
92
94
  result = {
93
95
  'id': self.id,
94
96
  'node_type': self.op.value,
95
- 'output_data': format_node_data(self.output_data),
96
- 'input_data': format_node_data(self.input_data),
97
+ 'output_data': format_node_data(self.output_data, self.id),
98
+ 'input_data': format_node_data(self.input_data, self.id),
97
99
  'upnode': self.upnode.id if self.upnode else 'None',
98
100
  'subnodes': [node.id for node in self.subnodes],
99
101
  'matched_node_link': self.matched_node_link,
@@ -113,7 +115,13 @@ class BaseNode:
113
115
  """
114
116
  ancestors = []
115
117
  current_node = self.upnode
118
+ seen_nodes = set()
116
119
  while current_node:
120
+ if current_node.id in seen_nodes:
121
+ logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
122
+ f'current node is {current_node.id}.')
123
+ return []
124
+ seen_nodes.add(current_node.id)
117
125
  ancestors.append(current_node.id)
118
126
  current_node = current_node.upnode
119
127
  return list(reversed(ancestors))
@@ -115,7 +115,9 @@ class DistributedAnalyzer:
115
115
  if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
116
116
  continue
117
117
  api_name, distributed_type = self._get_distributed_name_and_type(node_id)
118
- if distributed_type == DistributedType.P2P:
118
+ if api_name == GraphConst.BATCH_P2P:
119
+ self._batch_p2p_match(node, rank)
120
+ elif distributed_type == DistributedType.P2P:
119
121
  self._p2p_match(node, rank, api_name)
120
122
  else:
121
123
  self._collective_match(node, rank, api_name)
@@ -138,12 +140,16 @@ class DistributedAnalyzer:
138
140
  for rank, graph in self.graphs.items():
139
141
  group_count = {}
140
142
  group_info = {}
143
+ batch_p2p_count = {}
141
144
  nodes = graph.node_map
142
145
  for node_id, node in nodes.items():
143
146
  if not node_id.startswith(Const.DISTRIBUTED):
144
147
  continue
145
148
  api_name, distributed_type = self._get_distributed_name_and_type(node_id)
146
- if distributed_type == DistributedType.P2P:
149
+ if api_name == GraphConst.BATCH_P2P:
150
+ self._make_batch_p2p_mapping(node, rank, batch_p2p_count)
151
+ continue
152
+ elif distributed_type == DistributedType.P2P:
147
153
  config_info = self.config.get(api_name)
148
154
  target_rank = self._get_target_rank(node, rank, config_info[1])
149
155
  if target_rank is None:
@@ -162,7 +168,32 @@ class DistributedAnalyzer:
162
168
  unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
163
169
  group_info[unique_group_id] = node_id
164
170
  group_info[node_id] = unique_group_id
165
- self.group_node_mapping[rank] = group_info
171
+ if rank not in self.group_node_mapping:
172
+ self.group_node_mapping[rank] = {}
173
+ self.group_node_mapping[rank].update(group_info)
174
+
175
+ def _make_batch_p2p_mapping(self, node, rank, batch_p2p_count):
176
+ """
177
+ 给batch_isend_irecv接口的每个p2p内容赋予唯一标识
178
+ """
179
+ if rank not in self.group_node_mapping:
180
+ self.group_node_mapping[rank] = {}
181
+ params = []
182
+ for info_dict in node.batch_p2p_info:
183
+ op = info_dict.get(GraphConst.OP)
184
+ target_rank = info_dict.get(GraphConst.PEER)
185
+ if op is None or target_rank is None:
186
+ logger.warning('Cannot get param op or peer.')
187
+ continue
188
+ group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
189
+ Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
190
+ batch_p2p_count[group_id] = batch_p2p_count.get(group_id, 0) + 1
191
+ # 例如: isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1
192
+ unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(batch_p2p_count.get(group_id))
193
+ params.append(unique_group_id)
194
+ self.group_node_mapping.get(rank)[unique_group_id] = node.id
195
+ if params:
196
+ self.group_node_mapping.get(rank)[node.id] = params
166
197
 
167
198
  def _get_distributed_name_and_type(self, node_id):
168
199
  if Const.SEP not in node_id:
@@ -316,3 +347,40 @@ class DistributedAnalyzer:
316
347
  if nodes_info:
317
348
  matched_distributed['nodes_info'] = nodes_info
318
349
  node.matched_distributed = matched_distributed
350
+
351
+ def _batch_p2p_match(self, node, rank):
352
+ """
353
+ 批量点对点匹配
354
+
355
+ 针对torch.distributed.batch_isend_irecv接口,其入参是一个包含点对点通信信息的集合,需要遍历集合对每个点对点通信信息进行匹配
356
+ :param node: 当前集体通信节点
357
+ :param rank: 当前节点所属rank
358
+ :return:
359
+ """
360
+ unique_group_ids = self.group_node_mapping.get(rank, {}).get(node.id)
361
+ if not unique_group_ids:
362
+ return
363
+ matched_distributed = [] if len(unique_group_ids) > 1 else {}
364
+ for unique_group_id in unique_group_ids:
365
+ try:
366
+ id_info = unique_group_id.split(Const.REPLACEMENT_CHARACTER)
367
+ api_name = id_info[0]
368
+ target_api_name = self.config.get(api_name)[0]
369
+ target_rank = int(id_info[1].replace(Const.RANK, ''))
370
+ except Exception as e:
371
+ logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
372
+ continue
373
+ target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
374
+ if not target_node:
375
+ continue
376
+ communications_type = self.config.get(api_name)[2]
377
+ index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
378
+ else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
379
+ matched_info = {
380
+ 'communications_type': communications_type,
381
+ 'nodes_info': {target_rank: [str(index), target_node.id]}
382
+ }
383
+ matched_distributed.append(matched_info) if isinstance(matched_distributed, list) \
384
+ else matched_distributed.update(matched_info)
385
+ if matched_distributed:
386
+ node.matched_distributed = matched_distributed
@@ -20,9 +20,6 @@ from msprobe.core.common.log import logger
20
20
  from msprobe.core.common.const import Const
21
21
 
22
22
 
23
- MAX_RECUR_LEVEL = 100
24
-
25
-
26
23
  class Graph:
27
24
  def __init__(self, model_name, data_path='', dump_data=None):
28
25
  self.node_map = {}
@@ -67,7 +64,6 @@ class Graph:
67
64
  ancestors_b = node_b.get_ancestors()
68
65
  return node_b, ancestors_n, ancestors_b
69
66
 
70
-
71
67
  @staticmethod
72
68
  def fuzzy_match(node_n, node_b):
73
69
  if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
@@ -76,13 +72,6 @@ class Graph:
76
72
  ancestors_b = node_b.get_ancestors()
77
73
  return node_b, ancestors_n, ancestors_b
78
74
 
79
- @staticmethod
80
- def dfs(node, result):
81
- info = node.to_dict()
82
- result[node.id] = info
83
- for subnode in node.subnodes:
84
- Graph.dfs(subnode, result)
85
-
86
75
  @staticmethod
87
76
  def split_nodes_by_micro_step(nodes):
88
77
  """
@@ -16,6 +16,7 @@
16
16
  from enum import Enum
17
17
  import re
18
18
  from msprobe.visualization.builder.msprobe_adapter import op_patterns
19
+ from msprobe.core.common.log import logger
19
20
 
20
21
 
21
22
  class NodeOp(Enum):
@@ -23,7 +24,6 @@ class NodeOp(Enum):
23
24
  function_api = 1
24
25
  api_collection = 9
25
26
 
26
-
27
27
  @staticmethod
28
28
  def get_node_op(node_name: str):
29
29
  """
@@ -32,8 +32,9 @@ class NodeOp(Enum):
32
32
  for op in NodeOp:
33
33
  index = op.value
34
34
  if index < 0 or index >= len(op_patterns):
35
- raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
35
+ continue
36
36
  pattern = op_patterns[index]
37
37
  if re.match(pattern, node_name):
38
38
  return op
39
- raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
39
+ logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.")
40
+ return NodeOp.module
@@ -16,8 +16,8 @@
16
16
  import os
17
17
  import time
18
18
  import json
19
- from msprobe.core.common.file_utils import (FileOpen, check_file_type, create_directory, FileChecker,
20
- check_file_or_directory_path)
19
+ from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
20
+ check_file_or_directory_path, load_json)
21
21
  from msprobe.core.common.const import FileCheckConst, Const
22
22
  from msprobe.core.common.utils import CompareException
23
23
  from msprobe.core.overflow_check.checker import AnomalyDetector
@@ -159,7 +159,7 @@ def _compare_graph_steps(input_param, args):
159
159
  bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
160
160
 
161
161
  if npu_steps != bench_steps:
162
- logger.error('The number of steps in the two runs are different. Unable to match the steps.')
162
+ logger.error('The number of steps in the two runs is different. Unable to match the steps.')
163
163
  raise CompareException(CompareException.INVALID_PATH_ERROR)
164
164
 
165
165
  for folder_step in npu_steps:
@@ -220,8 +220,7 @@ def _graph_service_parser(parser):
220
220
 
221
221
 
222
222
  def _graph_service_command(args):
223
- with FileOpen(args.input_path, "r") as file:
224
- input_param = json.load(file)
223
+ input_param = load_json(args.input_path)
225
224
  npu_path = input_param.get("npu_path")
226
225
  bench_path = input_param.get("bench_path")
227
226
  check_file_or_directory_path(npu_path, isdir=True)
@@ -42,14 +42,6 @@ def load_data_json_file(file_path):
42
42
  return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
43
43
 
44
44
 
45
- def save_json_file(file_path, data):
46
- """
47
- 保存json文件
48
- """
49
- with FileOpen(file_path, 'w') as f:
50
- f.write(json.dumps(data, indent=4))
51
-
52
-
53
45
  def get_csv_df(stack_mode, csv_data, compare_mode):
54
46
  """
55
47
  调用acc接口写入csv
@@ -73,14 +65,6 @@ def str2float(percentage_str):
73
65
  return 0
74
66
 
75
67
 
76
- def is_integer(s):
77
- try:
78
- int(s)
79
- return True
80
- except Exception:
81
- return False
82
-
83
-
84
68
  def check_directory_content(input_path):
85
69
  """
86
70
  检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件
@@ -143,18 +127,17 @@ class ToolTip:
143
127
  '当最大相对误差越接近0表示其计算的误差越小。'
144
128
  '当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象'
145
129
  )
146
- SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差'
147
130
 
148
131
 
149
132
  class GraphConst:
150
133
  CONSTRUCT_FILE = 'construct.json'
151
134
  DUMP_FILE = 'dump.json'
152
135
  STACK_FILE = 'stack.json'
153
- GRAPH_FILE = 'graph.vis'
154
136
  ERROR_KEY = 'error_key'
155
137
  SUMMARY_COMPARE = 0
156
138
  MD5_COMPARE = 1
157
139
  REAL_DATA_COMPARE = 2
140
+ STRUCTURE_COMPARE = 3
158
141
  JSON_NPU_KEY = 'NPU'
159
142
  JSON_BENCH_KEY = 'Bench'
160
143
  JSON_TIP_KEY = 'ToolTip'
@@ -163,35 +146,22 @@ class GraphConst:
163
146
  JSON_DATA_KEY = 'dump_data_dir'
164
147
  JSON_TASK_KEY = 'task'
165
148
  DATA_KEY = 'data'
166
- REAL_DATA_TH = 0.1
167
- MAX_RELATIVE_ERR_TH = 0.5
168
149
  ROUND_TH = 6
169
150
  JSON_INDEX_KEY = 'precision_index'
170
151
  MATCHED_DISTRIBUTED = 'matched_distributed'
171
152
  OVERFLOW_LEVEL = 'overflow_level'
172
153
  MAX_INDEX_KEY = 1
173
154
  MIN_INDEX_KEY = 0
174
- SUGGEST_KEY = 'text'
175
- TAG_NA = 'na'
176
- OUTPUT_INDEX_TWO = -2
177
- OUTPUT_INDEX_THREE = -3
178
- OUTPUT_MIN_LEN = 3
179
155
  INPUT = '.input.'
180
156
  OUTPUT = '.output.'
181
157
  STR_MAX_LEN = 50
182
- SMALL_VALUE = 1e-3
183
158
  MD5_INDEX_LIST = [CompareConst.RESULT]
184
- REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
185
- CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
186
- SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
187
- CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
188
- CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
189
- VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
159
+ REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
160
+ SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
190
161
  APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
191
162
  NULL = 'null'
192
163
  NONE = 'None'
193
164
  VALUE = 'value'
194
- BRACE = '{}'
195
165
  DESCRIPTION = 'description'
196
166
  COLORS = 'Colors'
197
167
  MICRO_STEPS = 'MicroSteps'
@@ -200,13 +170,15 @@ class GraphConst:
200
170
  DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING = {
201
171
  Const.ALL: REAL_DATA_COMPARE,
202
172
  Const.SUMMARY: SUMMARY_COMPARE,
203
- Const.MD5: MD5_COMPARE
173
+ Const.MD5: MD5_COMPARE,
174
+ Const.STRUCTURE: STRUCTURE_COMPARE
204
175
  }
205
176
 
206
177
  GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING = {
207
178
  REAL_DATA_COMPARE: Const.ALL,
208
179
  SUMMARY_COMPARE: Const.SUMMARY,
209
- MD5_COMPARE: Const.MD5
180
+ MD5_COMPARE: Const.MD5,
181
+ STRUCTURE_COMPARE: Const.STRUCTURE
210
182
  }
211
183
 
212
184
  RANKS = 'ranks'
@@ -215,3 +187,8 @@ class GraphConst:
215
187
 
216
188
  SRC = 'src'
217
189
  DST = 'dst'
190
+
191
+ BATCH_P2P = 'batch_isend_irecv'
192
+ OP = 'op'
193
+ PEER = 'peer'
194
+ GROUP_ID = 'group_id'