mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,16 +14,24 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import re
17
+
18
+ from msprobe.core.common.const import Const
19
+ from msprobe.core.common.file_utils import load_json
20
+ from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
+ from msprobe.visualization.builder.msprobe_adapter import op_patterns
17
22
  from msprobe.visualization.graph.graph import Graph
18
23
  from msprobe.visualization.graph.node_op import NodeOp
19
24
  from msprobe.visualization.utils import save_json_file, GraphConst
20
- from msprobe.visualization.builder.msprobe_adapter import get_input_output
21
- from msprobe.core.common.file_utils import load_json
22
25
 
23
26
 
24
27
  class GraphBuilder:
28
+ backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
29
+ forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
30
+ # 匹配以大写字母开头,后接任意字母,并以Template(结尾
31
+ template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
32
+
25
33
  @staticmethod
26
- def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
34
+ def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
27
35
  """
28
36
  GraphBuilder的对外提供的构图方法
29
37
  Args:
@@ -31,11 +39,14 @@ class GraphBuilder:
31
39
  data_path: dump.json路径
32
40
  stack_path: stack.json路径
33
41
  model_name: 模型名字,依赖外部输入
42
+ complete_stack: 完整的堆栈信息
34
43
  Returns: Graph,代表图的数据结构
35
44
  """
36
45
  construct_dict = load_json(construct_path)
37
46
  dump_dict = load_json(data_path)
38
47
  stack_dict = load_json(stack_path)
48
+ if not complete_stack:
49
+ GraphBuilder._simplify_stack(stack_dict)
39
50
  data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
40
51
  graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
41
52
  GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
@@ -61,20 +72,59 @@ class GraphBuilder:
61
72
  result[GraphConst.MICRO_STEPS] = config.micro_steps
62
73
  if config.task:
63
74
  result[GraphConst.JSON_TASK_KEY] = config.task
75
+ result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
64
76
  save_json_file(filename, result)
65
77
 
78
+ @staticmethod
79
+ def _simplify_stack(stack_dict):
80
+ """
81
+ 精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
82
+
83
+ 例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
84
+ 保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
85
+
86
+ 例如Api Tensor.__iadd__.4.forward,堆栈为:
87
+ "File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
88
+ "File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
89
+ 匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
90
+ """
91
+ module_pattern = re.compile(op_patterns[0])
92
+ for dump_name, stack_list in stack_dict.items():
93
+ if not isinstance(stack_list, list):
94
+ continue
95
+ if module_pattern.match(dump_name):
96
+ parts = dump_name.split(Const.SEP)
97
+ if len(parts) < abs(Const.LAYER_NAME_INDEX):
98
+ continue
99
+ module_name = parts[Const.LAYER_NAME_INDEX]
100
+ for stack in stack_list:
101
+ if re.search(module_name + r'\(', stack):
102
+ stack_list = [stack]
103
+ break
104
+ else:
105
+ for index, stack in enumerate(stack_list):
106
+ if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
107
+ stack_list = [stack_list[index + 1]]
108
+ break
109
+ stack_dict[dump_name] = stack_list
110
+
66
111
  @staticmethod
67
112
  def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
68
113
  """
69
114
  如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点
70
115
  """
71
116
  # 匹配以.backward.后跟一个或多个数字结尾的模式
72
- backward_pattern = r"(\.backward\.)(\d+)$"
73
- forward_pattern = r"(\.forward\.)(\d+)$"
74
- if re.search(backward_pattern, subnode_id) and not upnode_id:
75
- forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id))
117
+ if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id:
118
+ forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id))
119
+ if forward_upnode_id:
120
+ new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id)
121
+ if new_upnode_id in construct_dict:
122
+ return new_upnode_id
123
+ # 匹配以.backward结尾的节点
124
+ if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id:
125
+ forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD))
76
126
  if forward_upnode_id:
77
- new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id)
127
+ new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD)
78
128
  if new_upnode_id in construct_dict:
79
129
  return new_upnode_id
80
130
  return upnode_id
@@ -104,11 +154,42 @@ class GraphBuilder:
104
154
  input_data, output_data = get_input_output(node_data, node.id)
105
155
  # 更新数据
106
156
  node.set_input_output(input_data, output_data)
157
+ if GraphConst.BATCH_P2P in name:
158
+ GraphBuilder._extract_batch_p2p_info(node, node_data)
159
+ # 反向节点使用对应前向节点的堆栈信息
160
+ # 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward
161
+ if (not node_stack_info and
162
+ (GraphBuilder.backward_pattern.search(name) or name.endswith(f'{Const.SEP}{Const.BACKWARD}'))):
163
+ forward_node = graph.get_node(
164
+ # 同名模块全局唯一,无论调用几次堆栈信息都一致,直接使用编号0的同名模块堆栈信息,避免遗漏
165
+ GraphBuilder.backward_pattern.sub(f'{Const.SEP}{Const.FORWARD}{Const.SEP}0', name)) \
166
+ if GraphBuilder.backward_pattern.search(name) \
167
+ else graph.get_node(name.replace(Const.BACKWARD, Const.FORWARD))
168
+ node_stack_info = forward_node.stack_info if forward_node \
169
+ else ['This backward node cannot find the forward node and cannot retrieve stack information.']
107
170
  node.stack_info = node_stack_info
108
171
  # 添加节点
109
172
  node.add_upnode(upnode)
110
173
  return node
111
174
 
175
+ @staticmethod
176
+ def _is_valid_batch_p2p_output(param_list):
177
+ if not isinstance(param_list, list) or not param_list:
178
+ return False
179
+ if not isinstance(param_list[0], list) or not param_list[0]:
180
+ return False
181
+ return True
182
+
183
+ @staticmethod
184
+ def _extract_batch_p2p_info(node, node_data):
185
+ param_list = node_data.get(Const.OUTPUT, [])
186
+ # 数据格式:"output": [[{param1}, {param2}, ...]]
187
+ if GraphBuilder._is_valid_batch_p2p_output(param_list):
188
+ for param in param_list[0]:
189
+ info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
190
+ GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
191
+ node.batch_p2p_info.append(info)
192
+
112
193
  @staticmethod
113
194
  def _collect_apis_between_modules(graph):
114
195
  """
@@ -156,10 +237,12 @@ class GraphBuilder:
156
237
 
157
238
 
158
239
  class GraphExportConfig:
159
- def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''):
240
+ def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
241
+ overflow_check=False):
160
242
  self.graph_n = graph_n
161
243
  self.graph_b = graph_b
162
244
  self.tool_tip = tool_tip
163
245
  self.node_colors = node_colors
164
246
  self.micro_steps = micro_steps
165
247
  self.task = task
248
+ self.overflow_check = overflow_check
@@ -18,11 +18,12 @@ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
18
18
  from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
19
  from msprobe.visualization.utils import GraphConst
20
20
  from msprobe.core.common.const import Const
21
+ from msprobe.core.compare.acc_compare import ModeConfig
21
22
 
22
23
  # 用于将节点名字解析成对应的NodeOp的规则
23
24
  op_patterns = [
24
25
  # NodeOp.module
25
- r'^(Module.|Cell.)',
26
+ r'^(Module.|Cell.|optimizer|clip_grad)',
26
27
  # NodeOp.function_api
27
28
  r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
28
29
  ]
@@ -50,12 +51,14 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
50
51
  framework: 框架类型, pytorch或mindspore
51
52
  is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆
52
53
  """
54
+ mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
55
+
53
56
  if framework == Const.PT_FRAMEWORK:
54
57
  from msprobe.pytorch.compare.pt_compare import PTComparator
55
- return PTComparator().do_multi_process(dump_path_param, csv_path)
58
+ return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
56
59
  else:
57
- from msprobe.mindspore.compare.ms_compare import MSComparator
58
- ms_comparator = MSComparator()
60
+ from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
61
+ ms_comparator = MSComparator(mode_config, MappingConfig())
59
62
  ms_comparator.cross_frame = is_cross_frame
60
63
  return ms_comparator.do_multi_process(dump_path_param, csv_path)
61
64
 
@@ -105,11 +108,25 @@ def compare_data(data_dict_list1, data_dict_list2):
105
108
  return True
106
109
 
107
110
 
108
- def format_node_data(data_dict):
111
+ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
112
+ """
113
+ 模糊匹配,仅校验参数shape是否一致
114
+ """
115
+ for x, y in zip(data_dict_list1.values(), data_dict_list2.values()):
116
+ x_shape = x.get(Const.SHAPE)
117
+ y_shape = y.get(Const.SHAPE)
118
+ if x_shape != y_shape:
119
+ return False
120
+ return True
121
+
122
+
123
+ def format_node_data(data_dict, node_id=None):
109
124
  """
110
- 批量进行节点数据的输出
125
+ 删除节点数据中不需要展示的字段
111
126
  """
112
127
  del_list = ['requires_grad', 'full_op_name']
128
+ if node_id and GraphConst.BATCH_P2P in node_id:
129
+ del_list.extend(['op', 'peer', 'tag', 'group_id'])
113
130
  for _, value in data_dict.items():
114
131
  if not isinstance(value, dict):
115
132
  continue
@@ -179,6 +196,13 @@ def _format_data(data_dict):
179
196
  """
180
197
  pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
181
198
  all_null = False
199
+
200
+ keys_to_keep = ['type', 'group_ranks', 'group_id', 'data_name']
201
+ if data_dict.get('type') == 'torch.ProcessGroup':
202
+ keys_to_remove = [key for key in data_dict if key not in keys_to_keep]
203
+ for key in keys_to_remove:
204
+ del data_dict[key]
205
+
182
206
  for key, value in data_dict.items():
183
207
  if isinstance(value, str):
184
208
  # 将单引号删掉,None换成null避免前端解析错误
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import re
16
17
  from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
17
18
  from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
18
19
  from msprobe.visualization.graph.graph import Graph, NodeOp
@@ -22,18 +23,23 @@ from msprobe.core.common.const import Const
22
23
 
23
24
 
24
25
  class GraphComparator:
25
- def __init__(self, graphs, dump_path_param, output_path, framework=Const.PT_FRAMEWORK, mapping_dict=None):
26
+ def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
26
27
  self.graph_n = graphs[0]
27
28
  self.graph_b = graphs[1]
28
- self._parse_param(dump_path_param, output_path)
29
- self.framework = framework
29
+ self._parse_param(dump_path_param, args.output_path)
30
+ self.framework = args.framework
30
31
  self.mapping_dict = mapping_dict
32
+ self.fuzzy_match = args.fuzzy_match
33
+ self.pattern = re.compile(r'\.\d+\.')
31
34
 
32
35
  def compare(self):
33
36
  """
34
37
  比较函数,初始化结束后单独调用。比较结果写入graph_n
35
38
  """
36
- self._compare_nodes(self.graph_n.root)
39
+ if self.fuzzy_match:
40
+ self._compare_nodes_fuzzy(self.graph_n.root)
41
+ else:
42
+ self._compare_nodes(self.graph_n.root)
37
43
  self._postcompare()
38
44
 
39
45
  def add_compare_result_to_node(self, node, compare_result_list):
@@ -60,8 +66,6 @@ class GraphComparator:
60
66
  self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
61
67
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
62
68
  node.data.update(other_dict)
63
- if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
64
- node.get_suggestions()
65
69
 
66
70
  def _parse_param(self, dump_path_param, output_path):
67
71
  self.dump_path_param = dump_path_param
@@ -82,8 +86,6 @@ class GraphComparator:
82
86
  for node in self.ma.compare_nodes:
83
87
  precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
84
88
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
85
- if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
86
- node.get_suggestions()
87
89
 
88
90
  def _handle_api_collection_index(self):
89
91
  """
@@ -120,11 +122,59 @@ class GraphComparator:
120
122
  node_n.add_link(node_b, ancestors)
121
123
  if node_b:
122
124
  # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
123
- compare_result_list = compare_node([node_n.id, node_b.id],
124
- [self.data_n_dict, self.data_b_dict],
125
- self.stack_json_data, self.ma.compare_mode)
126
- if compare_result_list:
127
- self.ma.add_csv_data(compare_result_list)
128
- self.add_compare_result_to_node(node_n, compare_result_list)
125
+ self._get_and_add_result(node_n, node_b)
129
126
  for subnode in node_n.subnodes:
130
127
  self._compare_nodes(subnode)
128
+
129
+ def _compare_nodes_fuzzy(self, node_n):
130
+ if node_n.op != NodeOp.function_api:
131
+ # 模块经过模糊匹配
132
+ node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
133
+ if node_b:
134
+ self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
135
+ # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
136
+ recount_result_n = self._recount_api_node(node_n)
137
+ recount_result_b = self._recount_api_node(node_b)
138
+ for recount_node_id, node_id_n in recount_result_n.items():
139
+ api_node_n = self.graph_n.node_map.get(node_id_n)
140
+ if not api_node_n:
141
+ continue
142
+ api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
143
+ api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
144
+ if api_node_b:
145
+ self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
146
+ for sub_node in node_n.subnodes:
147
+ self._compare_nodes_fuzzy(sub_node)
148
+
149
+ def _get_and_add_result(self, node_n, node_b):
150
+ compare_result_list = compare_node([node_n.id, node_b.id],
151
+ [self.data_n_dict, self.data_b_dict],
152
+ self.stack_json_data, self.ma.compare_mode)
153
+ if compare_result_list:
154
+ self.ma.add_csv_data(compare_result_list)
155
+ self.add_compare_result_to_node(node_n, compare_result_list)
156
+
157
+ def _recount_api_node(self, node):
158
+ """
159
+ 两个匹配上的模块, 忽略各自模块下所有api的dump调用次数, 并赋予模块中的调用顺序
160
+ Return:
161
+ {赋予模块中的调用顺序的node_id: 原始node_id}
162
+ """
163
+ recount_result = {}
164
+ node_count = {}
165
+ for sub_node in node.subnodes:
166
+ if sub_node.op == NodeOp.function_api:
167
+ # 忽略dump调用次数
168
+ count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
169
+ node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
170
+ # 赋予模块中的调用顺序
171
+ recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
172
+ recount_result[recount_node_id] = sub_node.id
173
+ return recount_result
174
+
175
+ def _process_matched_nodes(self, node_n, node_b, ancestors_n, ancestors_b):
176
+ ancestors_n.append(node_n.id)
177
+ ancestors_b.append(node_b.id)
178
+ node_n.matched_node_link = ancestors_b
179
+ node_b.matched_node_link = ancestors_n
180
+ self._get_and_add_result(node_n, node_b)
@@ -83,27 +83,13 @@ class ModeAdapter:
83
83
  continue
84
84
  compare_data = compare_data_dict.get(key)
85
85
  if compare_data:
86
- dtype = data_info.get(Const.DTYPE)
87
86
  # 对应比对结果csv的列
88
87
  key_list = GraphConst.SUMMARY_INDEX_LIST
89
88
  headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
90
89
  id_list = [headers.index(x) for x in key_list]
91
90
  ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
92
- for index, item in enumerate(key_list[4:]):
93
- value = data_info.get(GraphConst.VALUE_INDEX_LIST[index])
94
- value_diff = data_info.get(key_list[index])
91
+ for item in key_list[4:]:
95
92
  relative_err = str2float(data_info.get(item))
96
- if isinstance(value, float) and isinstance(value_diff, float) \
97
- and dtype in GraphConst.SMALL_VALUES.keys():
98
- small_value = GraphConst.SMALL_VALUES.get(dtype)
99
- # 小值域
100
- if abs(value) <= small_value:
101
- data_info[item] = ToolTip.SMALL_VALUE_TIP.format(data_info.get(item),
102
- GraphConst.VALUE_INDEX_LIST[index],
103
- small_value)
104
- relative_err = GraphConst.MIN_INDEX_KEY \
105
- if abs(value_diff) <= GraphConst.SMALL_VALUES_ABS_ERROR.get(dtype) \
106
- else GraphConst.MAX_INDEX_KEY
107
93
  max_relative_err = max(max_relative_err, relative_err)
108
94
  node_data[key] = data_info
109
95
  max_relative_err = 1 if max_relative_err > 1 else max_relative_err
@@ -14,8 +14,8 @@
14
14
  # limitations under the License.
15
15
  from msprobe.core.overflow_check.level import OverflowLevel
16
16
  from msprobe.visualization.graph.node_op import NodeOp
17
- from msprobe.visualization.utils import Suggestions, GraphConst
18
- from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data
17
+ from msprobe.visualization.utils import GraphConst
18
+ from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
19
19
 
20
20
 
21
21
  class BaseNode:
@@ -33,6 +33,8 @@ class BaseNode:
33
33
  self.stack_info = []
34
34
  self.micro_step_id = None
35
35
  self.overflow_level = None
36
+ self.matched_distributed = {}
37
+ self.batch_p2p_info = []
36
38
 
37
39
  def __str__(self):
38
40
  info = f'id:\t{self.id}'
@@ -48,16 +50,12 @@ class BaseNode:
48
50
  return False
49
51
  return True
50
52
 
51
- def get_suggestions(self):
52
- """
53
- 精度疑似有问题时,提供一些建议
54
- """
55
- if self.op == NodeOp.module:
56
- self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module
57
- self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
58
- elif self.op == NodeOp.function_api:
59
- self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
60
- self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL
53
+ def fuzzy_eq(self, other):
54
+ if not compare_data_fuzzy(self.input_data, other.input_data):
55
+ return False
56
+ if not compare_data_fuzzy(self.output_data, other.output_data):
57
+ return False
58
+ return True
61
59
 
62
60
  def set_input_output(self, input_data, output_data):
63
61
  self.input_data = input_data
@@ -67,6 +65,7 @@ class BaseNode:
67
65
  if not level or not isinstance(level, OverflowLevel):
68
66
  return
69
67
  self.overflow_level = level
68
+ self.data[GraphConst.OVERFLOW_LEVEL] = self.overflow_level.value
70
69
 
71
70
  def add_upnode(self, node):
72
71
  """
@@ -94,8 +93,8 @@ class BaseNode:
94
93
  result = {
95
94
  'id': self.id,
96
95
  'node_type': self.op.value,
97
- 'output_data': format_node_data(self.output_data),
98
- 'input_data': format_node_data(self.input_data),
96
+ 'output_data': format_node_data(self.output_data, self.id),
97
+ 'input_data': format_node_data(self.input_data, self.id),
99
98
  'upnode': self.upnode.id if self.upnode else 'None',
100
99
  'subnodes': [node.id for node in self.subnodes],
101
100
  'matched_node_link': self.matched_node_link,
@@ -104,12 +103,9 @@ class BaseNode:
104
103
  }
105
104
  if self.micro_step_id is not None:
106
105
  result['micro_step_id'] = self.micro_step_id
107
- # 是否存在overflow,并保存结果
108
- if self.overflow_level and isinstance(self.overflow_level, OverflowLevel):
109
- if self.data is None:
110
- self.data = dict()
111
- self.data['overflow_level'] = self.overflow_level.value
112
106
  result['data'] = self.data
107
+ if self.matched_distributed:
108
+ result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
113
109
  return result
114
110
 
115
111
  def get_ancestors(self):