mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
msprobe/visualization/builder/msprobe_adapter.py

@@ -28,7 +28,7 @@ op_patterns = [
     # NodeOp.module
     r'^(Module.|Cell.|optimizer|clip_grad)',
     # NodeOp.function_api
-    r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
+    r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.|MindSpeed.)'
 ]
 
 
@@ -54,7 +54,13 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
         framework: the framework type, pytorch or mindspore
         is_cross_frame: whether to run a cross-framework comparison; only mindspore vs. pytorch is supported, with pytorch as the baseline
     """
-    mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
+    config_dict = {
+        'stack_mode': False,
+        'auto_analyze': True,
+        'fuzzy_match': False,
+        'dump_mode': Const.ALL
+    }
+    mode_config = ModeConfig(**config_dict)
 
     if framework == Const.PT_FRAMEWORK:
         from msprobe.pytorch.compare.pt_compare import read_real_data
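Note: the change above only swaps keyword arguments for a dict unpacked into ModeConfig. A tiny sketch of the ** unpacking semantics, using a stand-in function since ModeConfig's full signature is not shown in this diff:

    # Stand-in for ModeConfig; ** unpacks the dict entries as keyword arguments.
    def mode_config_stub(stack_mode, auto_analyze, fuzzy_match, dump_mode):
        return (stack_mode, auto_analyze, fuzzy_match, dump_mode)

    config_dict = {'stack_mode': False, 'auto_analyze': True, 'fuzzy_match': False, 'dump_mode': 'all'}
    assert mode_config_stub(**config_dict) == mode_config_stub(False, True, False, 'all')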
@@ -125,7 +131,7 @@ def format_node_data(data_dict, node_id=None, compare_mode=None):
     """
     Remove fields from the node data that do not need to be displayed
     """
-    del_list = ['requires_grad', 'full_op_name']
+    del_list = ['state', 'full_op_name']
     if GraphConst.MD5_COMPARE != compare_mode:
         del_list.append(Const.MD5)
     if node_id and GraphConst.BATCH_P2P in node_id:
@@ -140,31 +146,27 @@ def format_node_data(data_dict, node_id=None, compare_mode=None):
     return data_dict
 
 
-def compare_node(node_ids, data_dicts, stack_json_data, compare_mode):
+def compare_node(node_n, node_b, compare_mode):
     """
     Call get_accuracy in acc_compare.py to obtain the accuracy comparison metrics.
     Real-data comparison mode cannot obtain the metrics this way and needs to call the multiprocessing comparison interface.
     Returns: a list containing the parameter info and the comparison metrics (except in real-data comparison mode)
     """
-    merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, compare_mode)
-    merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, compare_mode)
-    result = []
     dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
+    merge_n = _parse_node(node_n, dump_mode)
+    merge_b = _parse_node(node_b, dump_mode)
+    result = []
     get_accuracy(result, merge_n, merge_b, dump_mode)
     return result
 
 
-def _parse_node(node_id, data_dict, stack_json_data, compare_mode):
+def _parse_node(node, dump_mode):
     """
     Convert a node so that it can be used as an input to get_accuracy in acc_compare.py
     """
-    dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
-    op_parsed_list = read_op(data_dict.get(node_id, {}), node_id)
-    if node_id in stack_json_data:
-        op_parsed_list.append(
-            {'full_op_name': node_id, 'full_info': stack_json_data[node_id]})
-    else:
-        op_parsed_list.append({'full_op_name': node_id, 'full_info': None})
+    op_parsed_list = []
+    op_parsed_list.extend(node.input_data.values())
+    op_parsed_list.extend(node.output_data.values())
     result = merge_tensor(op_parsed_list, dump_mode)
     if not result:
         result['op_name'] = []
msprobe/visualization/compare/graph_comparator.py

@@ -35,13 +35,15 @@ class GraphComparator:
         self.fuzzy_match = args.fuzzy_match
         self.pattern = re.compile(r'\.\d+\.')
         self.is_cross_framework = is_cross_framework
+        self.parallel_merge = args.parallel_merge if hasattr(args, 'parallel_merge') else False
+        self.rank_pattern = re.compile(r"_rank\d+")
 
     def compare(self):
         """
        Comparison function, called on its own after initialization. Comparison results are written into graph_n.
         """
         if self.fuzzy_match:
-            self._compare_nodes_fuzzy(self.graph_n.root)
+            self._compare_nodes_fuzzy(self.graph_n.root, False if self.parallel_merge else True)
         else:
             self._compare_nodes(self.graph_n.root)
         self._postcompare()
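Side note: the hasattr/ternary above is the long form of getattr with a default, which tolerates older argument namespaces that lack the parallel_merge flag. A minimal illustration (the Namespace here is a hypothetical stand-in for the real args object):

    from argparse import Namespace

    args = Namespace()  # hypothetical args without the parallel_merge flag
    # Equivalent to: args.parallel_merge if hasattr(args, 'parallel_merge') else False
    parallel_merge = getattr(args, 'parallel_merge', False)
    print(parallel_merge)  # False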
@@ -98,11 +100,12 @@ class GraphComparator:
         while node_list:
             compare_single_node(node_list.pop(0))
 
-    def _compare_nodes_fuzzy(self, node_root):
+    def _compare_nodes_fuzzy(self, node_root, check_shape=True):
         def compare_single_nodes_fuzzy(node_n):
             if node_n.op != NodeOp.function_api:
                 # modules are matched via fuzzy matching
-                node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
+                node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id),
+                                                                     check_shape)
                 if node_b:
                     self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
                 # all apis inside the two matched modules are matched by identical name plus call order within the module, ignoring the dump call count
@@ -113,7 +116,7 @@ class GraphComparator:
                 if not api_node_n:
                     continue
                 api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
-                    api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
+                    api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)), check_shape)
                 if api_node_b:
                     self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
             node_list.extend(node_n.subnodes)
@@ -147,21 +150,26 @@ class GraphComparator:
        The metric of an api collection: md5 mode uses the smallest metric among all apis in the collection; statistics and tensor modes use the largest.
        In md5 mode a metric of 0 is the worst; in statistics and tensor modes a metric of 1 is the worst.
         """
+        def handle_api_collection_index(api_collection_node):
+            precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
+                else GraphConst.MIN_INDEX_KEY
+            for api in api_collection_node.subnodes:
+                precision_index = min(precision_index,
+                                      api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
+                    if self.ma.compare_mode == GraphConst.MD5_COMPARE \
+                    else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
+            api_collection_node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+
         for node in self.graph_n.root.subnodes:
-            if node.op == NodeOp.api_collection:
-                precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
-                    else GraphConst.MIN_INDEX_KEY
-                for api in node.subnodes:
-                    precision_index = min(precision_index,
-                                          api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
-                        if self.ma.compare_mode == GraphConst.MD5_COMPARE \
-                        else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
-                node.data[GraphConst.JSON_INDEX_KEY] = precision_index
+            if node.op == NodeOp.api_collection and node.id.startswith(GraphConst.APIS_BETWEEN_MODULES_ALL_RANKS):
+                for sub_node in node.subnodes:
+                    handle_api_collection_index(sub_node)
+                handle_api_collection_index(node)
+            elif node.op == NodeOp.api_collection:
+                handle_api_collection_index(node)
 
     def _get_and_add_result(self, node_n, node_b):
-        compare_result_list = compare_node([node_n.id, node_b.id],
-                                           [self.data_n_dict, self.data_b_dict],
-                                           self.stack_json_data, self.ma.compare_mode)
+        compare_result_list = compare_node(node_n, node_b, self.ma.compare_mode)
         if compare_result_list:
             self.ma.add_csv_data(compare_result_list)
             self.add_compare_result_to_node(node_n, compare_result_list)
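For reference, a standalone sketch of the aggregation rule that handle_api_collection_index applies, with plain floats standing in for the BaseNode.data entries and 1.0/0.0 standing in for GraphConst.MAX_INDEX_KEY/MIN_INDEX_KEY (both stand-ins are assumptions for illustration):

    def aggregate_precision_index(api_indexes, md5_mode):
        # md5 mode: start from the best value (1.0) and keep the minimum, so any mismatch drags it down;
        # statistics/tensor modes: start from 0.0 and keep the maximum, so the worst metric wins.
        worst = 1.0 if md5_mode else 0.0
        for idx in api_indexes:
            worst = min(worst, idx) if md5_mode else max(worst, idx)
        return worst

    assert aggregate_precision_index([1.0, 0.0, 1.0], md5_mode=True) == 0.0
    assert aggregate_precision_index([0.2, 0.9, 0.5], md5_mode=False) == 0.9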
@@ -178,6 +186,8 @@ class GraphComparator:
             if sub_node.op == NodeOp.function_api:
                 # ignore the dump call count
                 count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
+                if self.rank_pattern.search(count_removed_id):
+                    count_removed_id = self.rank_pattern.sub('', count_removed_id)
                 node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
                 # assign the call order within the module
                 recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
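A quick illustration of what the new rank_pattern step adds to the recount: after the existing pattern strips the dump call count, the rank suffix is also removed, so the same api dumped on different ranks recounts under one id (the node id below is hypothetical):

    import re

    pattern = re.compile(r'\.\d+\.')        # same as self.pattern
    rank_pattern = re.compile(r"_rank\d+")  # same as self.rank_pattern

    node_id = 'Torch.matmul_rank3.2.forward'      # hypothetical api node id
    count_removed_id = pattern.sub('.', node_id)  # Const.SEP assumed to be '.' -> 'Torch.matmul_rank3.forward'
    count_removed_id = rank_pattern.sub('', count_removed_id)
    print(count_removed_id)                       # 'Torch.matmul.forward'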
msprobe/visualization/db_utils.py (new file)

@@ -0,0 +1,252 @@
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sqlite3
+import json
+import re
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import change_mode, check_path_before_create, FileChecker
+from msprobe.core.common.const import FileCheckConst
+from msprobe.visualization.utils import GraphConst
+from msprobe.visualization.builder.msprobe_adapter import format_node_data
+
+TEXT_PRIMARY_KEY = 'TEXT PRIMARY KEY'
+TEXT_NOT_NULL = 'TEXT NOT NULL'
+INTEGER_NOT_NULL = 'INTEGER NOT NULL'
+TEXT = 'TEXT'
+INTEGER = 'INTEGER'
+
+node_columns = {
+    'id': TEXT_PRIMARY_KEY,
+    'graph_id': TEXT_NOT_NULL,
+    'node_order': INTEGER_NOT_NULL,
+    'node_name': TEXT_NOT_NULL,
+    'node_type': TEXT_NOT_NULL,
+    'up_node': TEXT,
+    'sub_nodes': TEXT,
+    'precision_index': INTEGER,
+    'overflow_level': TEXT,
+    'micro_step_id': INTEGER_NOT_NULL,
+    'matched_node_link': TEXT,
+    'stack_id': TEXT,
+    'parallel_merge_info': TEXT,
+    'matched_distributed': TEXT,
+    'modified': INTEGER_NOT_NULL,
+    'input_data': TEXT,
+    'output_data': TEXT,
+    'data_source': TEXT,
+    'dump_data_dir': TEXT,
+    'step': INTEGER_NOT_NULL,
+    'rank': INTEGER_NOT_NULL
+}
+
+config_columns = {
+    'id': TEXT_PRIMARY_KEY,
+    'graph_type': TEXT_NOT_NULL,
+    'task': TEXT,
+    'tool_tip': TEXT,
+    'micro_steps': INTEGER,
+    'overflow_check': INTEGER,
+    'node_colors': TEXT_NOT_NULL,
+    'rank_list': TEXT_NOT_NULL,
+    'step_list': TEXT_NOT_NULL
+}
+
+stack_columns = {
+    'id': TEXT_PRIMARY_KEY,
+    'stack_info': TEXT
+}
+
+indexes = {
+    "index1": ["step", "rank", "data_source", "up_node", "node_order"],
+    "index2": ["step", "rank", "data_source", "node_name"],
+    "index3": ["step", "rank", "data_source", "node_order"],
+    "index4": ["step", "rank", "node_order"],
+    "index5": ["step", "rank", "micro_step_id", "node_order"],
+    "index6": ["step", "rank", "modified", "matched_node_link"]
+}
+
+SAFE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')
+
+
+def is_safe_identifier(name):
+    """Check whether an identifier is safe (guards against SQL injection)"""
+    return isinstance(name, str) and SAFE_NAME_PATTERN.match(name) is not None
+
+
+def create_table_sql_from_dict(table_name, columns_dict):
+    """
+    Generate a CREATE TABLE SQL statement from the given table name and column-definition dict.
+    """
+    if not is_safe_identifier(table_name):
+        raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!")
+
+    sql = f"CREATE TABLE IF NOT EXISTS {table_name} (\n"
+
+    column_definitions = []
+    for column_name, column_type in columns_dict.items():
+        if not is_safe_identifier(column_name):
+            raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!")
+
+        column_definitions.append(f"    {column_name} {column_type}")
+
+    sql += ",\n".join(column_definitions)
+    sql += "\n);"
+
+    return sql
+
+
+def create_insert_sql_from_dict(table_name, columns_dict, ignore_insert=False):
+    """
+    Generate an INSERT INTO SQL statement from the given table name and data dict.
+    """
+    if not is_safe_identifier(table_name):
+        raise ValueError(f"Invalid table name: {table_name} - potential SQL injection risk!")
+
+    columns = list(columns_dict.keys())
+
+    for column_name in columns:
+        if not is_safe_identifier(column_name):
+            raise ValueError(f"Invalid column name: {column_name} - potential SQL injection risk!")
+
+    placeholders = ["?"] * len(columns)
+
+    columns_string = ", ".join(columns)
+    placeholders_string = ", ".join(placeholders)
+
+    sql_prefix = "INSERT OR IGNORE INTO" if ignore_insert else "INSERT INTO"
+    sql = f"{sql_prefix} {table_name} ({columns_string}) VALUES ({placeholders_string})"
+    return sql
+
+
+def to_db(db_path, create_table_sql, insert_sql, data, db_insert_size=1000):
+    if not os.path.exists(db_path):
+        check_path_before_create(db_path)
+    else:
+        FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE,
+                    FileCheckConst.DB_SUFFIX).common_check()
+    try:
+        conn = sqlite3.connect(db_path)
+    except sqlite3.Error as e:
+        logger.error(f"Unable to create database connection: {e}")
+        raise RuntimeError("Unable to create database connection") from e
+
+    try:
+        cursor = conn.cursor()
+        cursor.execute(create_table_sql)
+        if len(data) == 1:
+            cursor.execute(insert_sql, data[0])
+            conn.commit()
+        else:
+            for i in range(0, len(data), db_insert_size):
+                batch = data[i:i + db_insert_size]
+                cursor.executemany(insert_sql, batch)
+                conn.commit()
+    except sqlite3.Error as e:
+        logger.error(f"An sqlite3 error occurred: {e}")
+        raise RuntimeError("An sqlite3 error occurred") from e
+    finally:
+        conn.close()
+
+
+def add_table_index(db_path):
+    FileChecker(db_path, FileCheckConst.FILE, FileCheckConst.READ_WRITE_ABLE, FileCheckConst.DB_SUFFIX).common_check()
+    try:
+        conn = sqlite3.connect(db_path)
+    except sqlite3.Error as e:
+        logger.error(f"Unable to create database connection: {e}")
+        raise RuntimeError("Unable to create database connection") from e
+
+    try:
+        cursor = conn.cursor()
+        for index_name, columns in indexes.items():
+            if not is_safe_identifier(index_name):
+                raise ValueError(f"Invalid index name: {index_name} - potential SQL injection risk!")
+
+            for column in columns:
+                if not is_safe_identifier(column):
+                    raise ValueError(f"Invalid column name in index: {column} - potential SQL injection risk!")
+
+            columns_str = ', '.join(columns)
+            index_sql = f'''
+                CREATE INDEX IF NOT EXISTS {index_name} ON tb_nodes ({columns_str});
+            '''
+            cursor.execute(index_sql)
+        conn.commit()
+    except sqlite3.Error as e:
+        logger.error(f"Failed to add table index: {e}")
+        raise RuntimeError("Failed to add table index") from e
+    finally:
+        conn.close()
+
+
+def post_process_db(db_path):
+    add_table_index(db_path)
+    change_mode(db_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def node_to_db(graph, db_name):
+    create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns)
+    insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns)
+    data = []
+    stack_dict = {}
+    for i, node in enumerate(graph.get_sorted_nodes()):
+        stack_info_text = json.dumps(node.stack_info)
+        if stack_info_text not in stack_dict:
+            stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict)
+        data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value,
+                     node.upnode.id if node.upnode else '',
+                     json.dumps([node.id for node in node.subnodes]) if node.subnodes else '',
+                     node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL),
+                     node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link),
+                     stack_dict.get(stack_info_text),
+                     json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '',
+                     json.dumps(node.matched_distributed), 0,
+                     json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)),
+                     json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)),
+                     graph.data_source, graph.data_path, graph.step, graph.rank))
+    to_db(db_name, create_table_sql, insert_sql, data)
+    stack_to_db(stack_dict, db_name)
+
+
+def config_to_db(config, db_name):
+    create_table_sql = create_table_sql_from_dict('tb_config', config_columns)
+    insert_sql = create_insert_sql_from_dict('tb_config', config_columns, ignore_insert=True)
+    data = [("1", "compare" if config.graph_b else "build", config.task, config.tool_tip, config.micro_steps,
+             config.overflow_check, json.dumps(config.node_colors), json.dumps(config.rank_list),
+             json.dumps(config.step_list))]
+    to_db(db_name, create_table_sql, insert_sql, data)
+
+
+def stack_to_db(stack_dict, db_name):
+    create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns)
+    insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns)
+    data = []
+    for stack_info_text, unique_id in stack_dict.items():
+        data.append((unique_id, stack_info_text))
+    to_db(db_name, create_table_sql, insert_sql, data)
+
+
+def get_graph_unique_id(graph):
+    return f'{graph.data_source}_{graph.step}_{graph.rank}'
+
+
+def get_node_unique_id(graph, node):
+    return f'{get_graph_unique_id(graph)}_{node.id}'
+
+
+def get_stack_unique_id(graph, stack_dict):
+    return f'{get_graph_unique_id(graph)}_{len(stack_dict)}'
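A minimal usage sketch of the helpers in this new file, wiring create_table_sql_from_dict, create_insert_sql_from_dict, and to_db together for the tb_stack table; the database path and row contents are illustrative only, and the module is assumed importable as msprobe.visualization.db_utils:

    import sqlite3
    from msprobe.visualization.db_utils import (create_table_sql_from_dict,
                                                create_insert_sql_from_dict, to_db, stack_columns)

    # Build DDL/DML from the column dicts defined above.
    create_sql = create_table_sql_from_dict('tb_stack', stack_columns)
    insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns)

    # to_db creates the table if needed and inserts rows in batches of db_insert_size.
    rows = [
        ('NPU_0_0_0', '["File train.py, line 10"]'),  # illustrative stack rows
        ('NPU_0_0_1', '["File model.py, line 42"]'),
    ]
    to_db('./compare.db', create_sql, insert_sql, rows)

    with sqlite3.connect('./compare.db') as conn:
        print(conn.execute('SELECT COUNT(*) FROM tb_stack').fetchone())  # (2,)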
msprobe/visualization/graph/base_node.py

@@ -36,6 +36,8 @@ class BaseNode:
         self.overflow_level = None
         self.matched_distributed = {}
         self.batch_p2p_info = []
+        self.rank = 0
+        self.parallel_merge_info = []
 
     def __str__(self):
         info = f'id:\t{self.id}'
@@ -87,28 +89,6 @@ class BaseNode:
         self.matched_node_link = ancestors
         node.matched_node_link = ancestors
 
-    def to_dict(self, compare_mode=None):
-        """
-        Output data
-        """
-        result = {
-            'id': self.id,
-            'node_type': self.op.value,
-            'output_data': format_node_data(self.output_data, self.id, compare_mode),
-            'input_data': format_node_data(self.input_data, self.id, compare_mode),
-            'upnode': self.upnode.id if self.upnode else 'None',
-            'subnodes': [node.id for node in self.subnodes],
-            'matched_node_link': self.matched_node_link,
-            'suggestions': self.suggestions,
-            'stack_info': self.stack_info
-        }
-        if self.micro_step_id is not None:
-            result['micro_step_id'] = self.micro_step_id
-        result['data'] = self.data
-        if self.matched_distributed:
-            result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
-        return result
-
     def get_ancestors(self):
         """
        Get the list of all ancestors of the node
msprobe/visualization/graph/distributed_analyzer.py

@@ -82,7 +82,7 @@ class DistributedAnalyzer:
         """
         target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value')
         if target_rank is None:
-            logger.warning(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            logger.debug(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
         return target_rank
 
     @staticmethod
@@ -95,15 +95,15 @@ class DistributedAnalyzer:
         """
         group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {})
         if not group:
-            logger.warning(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            logger.debug(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
             return None, None
         group_ranks = group.get('group_ranks')
         if not group_ranks:
-            logger.warning(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            logger.debug(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
             return None, None
         group_id = group.get('group_id')
         if not group_id:
-            logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
+            logger.debug(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
             return None, None
         return group_ranks, group_id
 
@@ -183,7 +183,7 @@ class DistributedAnalyzer:
             op = info_dict.get(GraphConst.OP)
             target_rank = info_dict.get(GraphConst.PEER)
             if op is None or target_rank is None:
-                logger.warning('Cannot get param op or peer.')
+                logger.debug('Cannot get param op or peer.')
                 continue
             group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
                        Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
@@ -215,7 +215,7 @@ class DistributedAnalyzer:
         """
         target_graph = self.graphs.get(target_rank)
         if not target_graph:
-            logger.warning(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
+            logger.debug(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
             return None
         target_group_mapping = self.group_node_mapping.get(target_rank)
         # p2p communication: to obtain the target node, the rank and api name in unique_group_id need to be replaced,
@@ -226,7 +226,7 @@ class DistributedAnalyzer:
         target_node_id = target_group_mapping.get(target_unique_group_id, '')
         target_node = target_graph.node_map.get(target_node_id)
         if not target_node:
-            logger.warning(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
+            logger.debug(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
             return None
         return target_node
 
@@ -276,13 +276,13 @@ class DistributedAnalyzer:
         source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {})
                        .get('value'))
         if source_rank is None:
-            logger.warning(
+            logger.debug(
                 f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, '
                 f'{CANNOT_MATCH}{source_rank}')
             return
         if source_rank != rank:
             # point-to-point communication: the rank info carried by the target node to be matched must equal the current rank
-            logger.warning(
+            logger.debug(
                 f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, '
                 f'but the data shows that {target_node.id} communicates with rank{source_rank}.'
                 f'The rank is inconsistent, cannot match distributed node')
@@ -291,7 +291,7 @@ class DistributedAnalyzer:
         # point-to-point communication: the output data of the two matched nodes must be consistent
         if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'),
                                                           target_node.output_data.get(target_node.id + '.output.0')):
-            logger.warning(f'{node.id} output of rank{rank} is different from the {target_node.id} '
+            logger.debug(f'{node.id} output of rank{rank} is different from the {target_node.id} '
                            f'output of rank{target_rank}, cannot match distributed node')
             return
 
@@ -332,7 +332,7 @@ class DistributedAnalyzer:
             if not target_group_id:
                 continue
             if group_id != target_group_id:
-                logger.warning(
+                logger.debug(
                     f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}'
                     f', but the data shows that the group id of the two nodes are different, '
                     f'cannot match distributed node')
@@ -368,7 +368,7 @@ class DistributedAnalyzer:
                 target_api_name = self.config.get(api_name)[0]
                 target_rank = int(id_info[1].replace(Const.RANK, ''))
             except Exception as e:
-                logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
+                logger.debug(f'Failed to parse batch p2p parameter with error info: {e}.')
                 continue
             target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
             if not target_node:
msprobe/visualization/graph/graph.py

@@ -18,16 +18,22 @@ from msprobe.visualization.graph.node_op import NodeOp
 from msprobe.visualization.utils import GraphConst
 from msprobe.core.common.log import logger
 from msprobe.core.common.const import Const
+from msprobe.core.common.decorator import recursion_depth_decorator
 
 
 class Graph:
-    def __init__(self, model_name, data_path='', dump_data=None):
+    def __init__(self, model_name, data_path='', dump_data=None, micro_step_num=None):
         self.node_map = {}
         self.node_id_map = {}
         self.add_node(NodeOp.module, model_name)
         self.root = self.get_node(model_name)
         self.data_path = data_path
         self.dump_data = dump_data
+        self.data_source = GraphConst.JSON_NPU_KEY
+        self.step = 0
+        self.rank = 0
+        self.compare_mode = GraphConst.SUMMARY_COMPARE
+        self.micro_step_num = micro_step_num
 
     def __str__(self):
         infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map]
@@ -65,8 +71,10 @@ class Graph:
         return node_b, ancestors_n, ancestors_b
 
     @staticmethod
-    def fuzzy_match(node_n, node_b):
-        if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
+    def fuzzy_match(node_n, node_b, check_shape=True):
+        if not node_n or not node_b:
+            return None, [], []
+        if check_shape and not node_n.fuzzy_eq(node_b):
             return None, [], []
         ancestors_n = node_n.get_ancestors()
         ancestors_b = node_b.get_ancestors()
@@ -116,6 +124,25 @@ class Graph:
                 result[micro_step].append(node)
         return result
 
+    def get_sorted_nodes(self):
+        """
+        Obtain a sorted list of nodes via a depth-first traversal of the graph
+        """
+        visited = set()
+        order = []
+
+        @recursion_depth_decorator('msprobe.visualization.graph.graph.Graph.get_nodes_order.visit', max_depth=500)
+        def visit(node):
+            if node.id in visited:
+                return
+            visited.add(node.id)
+            for sub_node in node.subnodes:
+                visit(sub_node)
+            order.append(node)
+
+        visit(self.root)
+        return order
+
     def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
         """
        Add a node to the graph
@@ -146,19 +173,6 @@ class Graph:
         """
         return self.node_map.get(node_id, None)
 
-    def to_dict(self, compare_mode=None):
-        """
-        Used for data output
-        """
-        result = {}
-        result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
-        result[GraphConst.JSON_DATA_KEY] = self.data_path
-        result[GraphConst.JSON_NODE_KEY] = {}
-        for node_id in self.node_map:
-            info = self.node_map.get(node_id).to_dict(compare_mode)
-            result[GraphConst.JSON_NODE_KEY][node_id] = info
-        return result
-
     def paging_by_micro_step(self, graph_other=None):
         """
        Add micro step tags to the graph's top-level nodes for paginated display in the frontend, which helps optimize and manage large-scale graph data
@@ -168,6 +182,18 @@ class Graph:
        graph_other: optional parameter, another graph
        Returns: the number of batches
         """
+
+        @recursion_depth_decorator(
+            'msprobe.visualization.graph.graph.Graph.paging_by_micro_step.propagate_micro_step_id', max_depth=500)
+        def propagate_micro_step_id(node):
+            if node.upnode is not None and node.micro_step_id is None:
+                node.micro_step_id = node.upnode.micro_step_id
+            for sub_node in node.subnodes:
+                propagate_micro_step_id(sub_node)
+
+        if self.micro_step_num is not None:
+            return self.micro_step_num + 1
+
         batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
         for batch_number, nodes in batches_n.items():
             for node in nodes:
@@ -177,6 +203,7 @@ class Graph:
                 node_other = graph_other.get_node(node.matched_node_link[-1])
                 if node_other:
                     node_other.micro_step_id = batch_number
+        propagate_micro_step_id(self.root)
         # walk all child nodes under graph_other's root to ensure unmatched nodes also get a micro_step_id
         if graph_other:
             for node in graph_other.root.subnodes:
@@ -186,6 +213,7 @@ class Graph:
                 except ValueError:
                     micro_step_id = 0
                 node.micro_step_id = micro_step_id
+        propagate_micro_step_id(graph_other.root)
         return len(batches_n)
 
     def overflow_check(self):
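The visit helper in get_sorted_nodes appends a node only after all of its subnodes, so the returned list is in post-order (children before parents), and the visited set guards against revisiting shared nodes; node_to_db then uses this ordering for the node_order column. A minimal sketch with a hypothetical stand-in node class:

    class FakeNode:  # stand-in exposing just the fields visit() touches
        def __init__(self, node_id, subnodes=()):
            self.id = node_id
            self.subnodes = list(subnodes)

    root = FakeNode('root', [FakeNode('a'), FakeNode('b')])
    visited, order = set(), []

    def visit(node):  # same traversal as Graph.get_sorted_nodes, minus the depth guard
        if node.id in visited:
            return
        visited.add(node.id)
        for sub_node in node.subnodes:
            visit(sub_node)
        order.append(node)

    visit(root)
    print([n.id for n in order])  # ['a', 'b', 'root']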