mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261) hide show
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,13 +12,16 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+
15
16
  import re
16
- import math
17
- from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
17
+
18
+ from msprobe.core.compare.acc_compare import ModeConfig
19
+ from msprobe.core.compare.multiprocessing_compute import CompareRealData
20
+ from msprobe.core.compare.utils import read_op, merge_tensor, get_accuracy, make_result_table
18
21
  from msprobe.core.common.utils import set_dump_path, get_dump_mode
19
22
  from msprobe.visualization.utils import GraphConst
20
23
  from msprobe.core.common.const import Const
21
- from msprobe.core.compare.acc_compare import ModeConfig
24
+
22
25
 
23
26
  # 用于将节点名字解析成对应的NodeOp的规则
24
27
  op_patterns = [
@@ -54,13 +57,11 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
54
57
  mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
55
58
 
56
59
  if framework == Const.PT_FRAMEWORK:
57
- from msprobe.pytorch.compare.pt_compare import PTComparator
58
- return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
60
+ from msprobe.pytorch.compare.pt_compare import read_real_data
61
+ return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
59
62
  else:
60
- from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
61
- ms_comparator = MSComparator(mode_config, MappingConfig())
62
- ms_comparator.cross_frame = is_cross_frame
63
- return ms_comparator.do_multi_process(dump_path_param, csv_path)
63
+ from msprobe.mindspore.compare.ms_compare import read_real_data
64
+ return CompareRealData(read_real_data, mode_config, is_cross_frame).do_multi_process(dump_path_param, csv_path)
64
65
 
65
66
 
66
67
  def get_input_output(node_data, node_id):
@@ -120,11 +121,13 @@ def compare_data_fuzzy(data_dict_list1, data_dict_list2):
120
121
  return True
121
122
 
122
123
 
123
- def format_node_data(data_dict, node_id=None):
124
+ def format_node_data(data_dict, node_id=None, compare_mode=None):
124
125
  """
125
126
  删除节点数据中不需要展示的字段
126
127
  """
127
128
  del_list = ['requires_grad', 'full_op_name']
129
+ if GraphConst.MD5_COMPARE != compare_mode:
130
+ del_list.append(Const.MD5)
128
131
  if node_id and GraphConst.BATCH_P2P in node_id:
129
132
  del_list.extend(['op', 'peer', 'tag', 'group_id'])
130
133
  for _, value in data_dict.items():
@@ -172,7 +175,7 @@ def _format_decimal_string(s):
172
175
  """
173
176
  使用正则表达式匹配包含数字、小数点和可选的百分号的字符串
174
177
  """
175
- pattern = re.compile(r'\d{1,20}\.\d{1,20}%?')
178
+ pattern = re.compile(r'^\d{1,20}\.\d{1,20}%?$')
176
179
  matches = pattern.findall(s)
177
180
  for match in matches:
178
181
  is_percent = match.endswith('%')
@@ -227,3 +230,12 @@ def _format_data(data_dict):
227
230
  if all_null:
228
231
  data_dict.clear()
229
232
  data_dict[GraphConst.VALUE] = GraphConst.NULL
233
+
234
+
235
+ def get_csv_df(stack_mode, csv_data, compare_mode):
236
+ """
237
+ 调用acc接口写入csv
238
+ """
239
+
240
+ dump_mode = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(compare_mode)
241
+ return make_result_table(csv_data, dump_mode, stack_mode)
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,23 +14,27 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import re
17
- from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
18
- from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
17
+ from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data, get_csv_df
18
+ from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file
19
19
  from msprobe.visualization.graph.graph import Graph, NodeOp
20
- from msprobe.visualization.graph.node_colors import NodeColors
21
20
  from msprobe.visualization.compare.mode_adapter import ModeAdapter
22
21
  from msprobe.core.common.const import Const
22
+ from msprobe.core.common.decorator import recursion_depth_decorator
23
23
 
24
24
 
25
25
  class GraphComparator:
26
- def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
26
+ MAX_DEPTH = 1000
27
+
28
+ def __init__(self, graphs, dump_path_param, args, is_cross_framework, mapping_dict=None):
27
29
  self.graph_n = graphs[0]
28
30
  self.graph_b = graphs[1]
29
31
  self._parse_param(dump_path_param, args.output_path)
30
32
  self.framework = args.framework
33
+ self.layer_mapping = args.layer_mapping
31
34
  self.mapping_dict = mapping_dict
32
35
  self.fuzzy_match = args.fuzzy_match
33
36
  self.pattern = re.compile(r'\.\d+\.')
37
+ self.is_cross_framework = is_cross_framework
34
38
 
35
39
  def compare(self):
36
40
  """
@@ -41,7 +45,7 @@ class GraphComparator:
41
45
  else:
42
46
  self._compare_nodes(self.graph_n.root)
43
47
  self._postcompare()
44
-
48
+
45
49
  def add_compare_result_to_node(self, node, compare_result_list):
46
50
  """
47
51
  将比对结果添加到节点的输入输出数据中
@@ -66,7 +70,58 @@ class GraphComparator:
66
70
  self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
67
71
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
68
72
  node.data.update(other_dict)
69
-
73
+
74
+ def _compare_nodes(self, node_root):
75
+ """
76
+ 遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
77
+ 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息
78
+ """
79
+ def compare_single_node(node_n):
80
+ if self.layer_mapping:
81
+ node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
82
+ if node_b:
83
+ ancestors_n.append(node_n.id)
84
+ ancestors_b.append(node_b.id)
85
+ node_n.matched_node_link = ancestors_b
86
+ node_b.matched_node_link = ancestors_n
87
+ else:
88
+ node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
89
+ if node_b:
90
+ ancestors.append(node_b.id)
91
+ node_n.add_link(node_b, ancestors)
92
+ if node_b:
93
+ # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
94
+ self._get_and_add_result(node_n, node_b)
95
+ node_list.extend(node_n.subnodes)
96
+
97
+ node_list = [node_root]
98
+ while node_list:
99
+ compare_single_node(node_list.pop(0))
100
+
101
+ def _compare_nodes_fuzzy(self, node_root):
102
+ def compare_single_nodes_fuzzy(node_n):
103
+ if node_n.op != NodeOp.function_api:
104
+ # 模块经过模糊匹配
105
+ node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
106
+ if node_b:
107
+ self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
108
+ # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
109
+ recount_result_n = self._recount_api_node(node_n)
110
+ recount_result_b = self._recount_api_node(node_b)
111
+ for recount_node_id, node_id_n in recount_result_n.items():
112
+ api_node_n = self.graph_n.node_map.get(node_id_n)
113
+ if not api_node_n:
114
+ continue
115
+ api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
116
+ api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
117
+ if api_node_b:
118
+ self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
119
+ node_list.extend(node_n.subnodes)
120
+
121
+ node_list = [node_root]
122
+ while node_list:
123
+ compare_single_nodes_fuzzy(node_list.pop(0))
124
+
70
125
  def _parse_param(self, dump_path_param, output_path):
71
126
  self.dump_path_param = dump_path_param
72
127
  self.output_path = output_path
@@ -81,7 +136,7 @@ class GraphComparator:
81
136
  if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
82
137
  return
83
138
  df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
84
- df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
139
+ df = run_real_data(self.dump_path_param, df, self.framework, self.is_cross_framework)
85
140
  compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
86
141
  for node in self.ma.compare_nodes:
87
142
  precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
@@ -103,49 +158,6 @@ class GraphComparator:
103
158
  else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
104
159
  node.data[GraphConst.JSON_INDEX_KEY] = precision_index
105
160
 
106
- def _compare_nodes(self, node_n):
107
- """
108
- 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
109
- 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息
110
- """
111
- if self.mapping_dict:
112
- node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_dict)
113
- if node_b:
114
- ancestors_n.append(node_n.id)
115
- ancestors_b.append(node_b.id)
116
- node_n.matched_node_link = ancestors_b
117
- node_b.matched_node_link = ancestors_n
118
- else:
119
- node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
120
- if node_b:
121
- ancestors.append(node_b.id)
122
- node_n.add_link(node_b, ancestors)
123
- if node_b:
124
- # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
125
- self._get_and_add_result(node_n, node_b)
126
- for subnode in node_n.subnodes:
127
- self._compare_nodes(subnode)
128
-
129
- def _compare_nodes_fuzzy(self, node_n):
130
- if node_n.op != NodeOp.function_api:
131
- # 模块经过模糊匹配
132
- node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
133
- if node_b:
134
- self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
135
- # 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
136
- recount_result_n = self._recount_api_node(node_n)
137
- recount_result_b = self._recount_api_node(node_b)
138
- for recount_node_id, node_id_n in recount_result_n.items():
139
- api_node_n = self.graph_n.node_map.get(node_id_n)
140
- if not api_node_n:
141
- continue
142
- api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
143
- api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
144
- if api_node_b:
145
- self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
146
- for sub_node in node_n.subnodes:
147
- self._compare_nodes_fuzzy(sub_node)
148
-
149
161
  def _get_and_add_result(self, node_n, node_b):
150
162
  compare_result_list = compare_node([node_n.id, node_b.id],
151
163
  [self.data_n_dict, self.data_b_dict],
@@ -13,8 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import json
17
16
  import math
17
+ import json
18
18
  from msprobe.core.common.const import CompareConst, Const
19
19
  from msprobe.visualization.utils import ToolTip, GraphConst, str2float
20
20
 
@@ -25,6 +25,12 @@ class ModeAdapter:
25
25
  self.csv_data = []
26
26
  self.compare_nodes = []
27
27
 
28
+ @staticmethod
29
+ def _is_invalid(value):
30
+ if not isinstance(value, float):
31
+ return False
32
+ return math.isnan(value) or math.isinf(value)
33
+
28
34
  @staticmethod
29
35
  def _add_md5_compare_data(node_data, compare_data_dict):
30
36
  precision_index = GraphConst.MAX_INDEX_KEY
@@ -49,6 +55,8 @@ class ModeAdapter:
49
55
  for key, value in node_data.items():
50
56
  if not isinstance(value, dict):
51
57
  continue
58
+ if value.get(Const.MAX) is None:
59
+ continue
52
60
  compare_data = compare_data_dict.get(key)
53
61
  if compare_data:
54
62
  headers = CompareConst.COMPARE_RESULT_HEADER
@@ -67,9 +75,13 @@ class ModeAdapter:
67
75
  if thousandth is not None:
68
76
  numbers.append(thousandth)
69
77
  node_data[key] = value
78
+ if ModeAdapter._is_invalid(value.get(Const.MAX)) or ModeAdapter._is_invalid(value.get(Const.MIN)):
79
+ numbers.append(CompareConst.N_A)
70
80
  # 双千指标都是None的异常情况
71
81
  if not numbers:
72
82
  min_thousandth = None
83
+ elif CompareConst.N_A in numbers:
84
+ min_thousandth = CompareConst.N_A
73
85
  else:
74
86
  min_thousandth = min(numbers + [min_thousandth])
75
87
  return min_thousandth
@@ -81,6 +93,8 @@ class ModeAdapter:
81
93
  for key, data_info in node_data.items():
82
94
  if not isinstance(data_info, dict):
83
95
  continue
96
+ if data_info.get(Const.MAX) is None:
97
+ continue
84
98
  compare_data = compare_data_dict.get(key)
85
99
  if compare_data:
86
100
  # 对应比对结果csv的列
@@ -92,6 +106,8 @@ class ModeAdapter:
92
106
  relative_err = str2float(data_info.get(item))
93
107
  max_relative_err = max(max_relative_err, relative_err)
94
108
  node_data[key] = data_info
109
+ if ModeAdapter._is_invalid(data_info.get(Const.MAX)) or ModeAdapter._is_invalid(data_info.get(Const.MIN)):
110
+ max_relative_err = GraphConst.MAX_INDEX_KEY
95
111
  max_relative_err = 1 if max_relative_err > 1 else max_relative_err
96
112
  return max_relative_err
97
113
 
@@ -133,7 +149,11 @@ class ModeAdapter:
133
149
  ModeAdapter._check_list_len(compare_data_dict_list, 1)
134
150
  min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
135
151
  min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
136
- if min_thousandth_in is not None and min_thousandth_out is not None:
152
+ if CompareConst.N_A == min_thousandth_out:
153
+ change_percentage = GraphConst.MAX_INDEX_KEY
154
+ elif CompareConst.N_A == min_thousandth_in:
155
+ change_percentage = GraphConst.MIN_INDEX_KEY
156
+ elif min_thousandth_in is not None and min_thousandth_out is not None:
137
157
  change_percentage = min_thousandth_in - min_thousandth_out
138
158
  else:
139
159
  change_percentage = GraphConst.MIN_INDEX_KEY
@@ -157,24 +177,6 @@ class ModeAdapter:
157
177
  return
158
178
  self.csv_data.extend(compare_result_list)
159
179
 
160
- def add_error_key(self, node_data):
161
- """
162
- 根据不同的模式进行提供不同错误信息
163
- """
164
- for key, value in node_data.items():
165
- if not isinstance(value, dict):
166
- continue
167
- if self.compare_mode == GraphConst.SUMMARY_COMPARE:
168
- message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
169
- CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
170
- elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
171
- message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
172
- else:
173
- # 输出件优化
174
- message = []
175
- value[GraphConst.ERROR_KEY] = message
176
- node_data[key] = value
177
-
178
180
  def get_tool_tip(self):
179
181
  """
180
182
  用于前端展示字段的具体含义
@@ -12,10 +12,11 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+
15
16
  from msprobe.core.overflow_check.level import OverflowLevel
16
- from msprobe.visualization.graph.node_op import NodeOp
17
17
  from msprobe.visualization.utils import GraphConst
18
18
  from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
19
+ from msprobe.core.common.log import logger
19
20
 
20
21
 
21
22
  class BaseNode:
@@ -86,15 +87,15 @@ class BaseNode:
86
87
  self.matched_node_link = ancestors
87
88
  node.matched_node_link = ancestors
88
89
 
89
- def to_dict(self):
90
+ def to_dict(self, compare_mode=None):
90
91
  """
91
92
  输出数据
92
93
  """
93
94
  result = {
94
95
  'id': self.id,
95
96
  'node_type': self.op.value,
96
- 'output_data': format_node_data(self.output_data, self.id),
97
- 'input_data': format_node_data(self.input_data, self.id),
97
+ 'output_data': format_node_data(self.output_data, self.id, compare_mode),
98
+ 'input_data': format_node_data(self.input_data, self.id, compare_mode),
98
99
  'upnode': self.upnode.id if self.upnode else 'None',
99
100
  'subnodes': [node.id for node in self.subnodes],
100
101
  'matched_node_link': self.matched_node_link,
@@ -114,7 +115,13 @@ class BaseNode:
114
115
  """
115
116
  ancestors = []
116
117
  current_node = self.upnode
118
+ seen_nodes = set()
117
119
  while current_node:
120
+ if current_node.id in seen_nodes:
121
+ logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
122
+ f'current node is {current_node.id}.')
123
+ return []
124
+ seen_nodes.add(current_node.id)
118
125
  ancestors.append(current_node.id)
119
126
  current_node = current_node.upnode
120
127
  return list(reversed(ancestors))
@@ -107,15 +107,6 @@ class DistributedAnalyzer:
107
107
  return None, None
108
108
  return group_ranks, group_id
109
109
 
110
- @staticmethod
111
- def _get_batch_group_info(node, rank):
112
- for data in node.input_data.values():
113
- group_id = data.get('group_id')
114
- if group_id is not None:
115
- return group_id
116
- logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
117
- return None
118
-
119
110
  def distributed_match(self):
120
111
  for rank, graph in self.graphs.items():
121
112
  nodes = graph.node_map
@@ -377,7 +368,7 @@ class DistributedAnalyzer:
377
368
  target_api_name = self.config.get(api_name)[0]
378
369
  target_rank = int(id_info[1].replace(Const.RANK, ''))
379
370
  except Exception as e:
380
- logger.warning(f'Failed to parsing batch p2p parameter with error info: {e}.')
371
+ logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
381
372
  continue
382
373
  target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
383
374
  if not target_node:
@@ -20,9 +20,6 @@ from msprobe.core.common.log import logger
20
20
  from msprobe.core.common.const import Const
21
21
 
22
22
 
23
- MAX_RECUR_LEVEL = 100
24
-
25
-
26
23
  class Graph:
27
24
  def __init__(self, model_name, data_path='', dump_data=None):
28
25
  self.node_map = {}
@@ -67,7 +64,6 @@ class Graph:
67
64
  ancestors_b = node_b.get_ancestors()
68
65
  return node_b, ancestors_n, ancestors_b
69
66
 
70
-
71
67
  @staticmethod
72
68
  def fuzzy_match(node_n, node_b):
73
69
  if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
@@ -76,13 +72,6 @@ class Graph:
76
72
  ancestors_b = node_b.get_ancestors()
77
73
  return node_b, ancestors_n, ancestors_b
78
74
 
79
- @staticmethod
80
- def dfs(node, result):
81
- info = node.to_dict()
82
- result[node.id] = info
83
- for subnode in node.subnodes:
84
- Graph.dfs(subnode, result)
85
-
86
75
  @staticmethod
87
76
  def split_nodes_by_micro_step(nodes):
88
77
  """
@@ -157,7 +146,7 @@ class Graph:
157
146
  """
158
147
  return self.node_map.get(node_id, None)
159
148
 
160
- def to_dict(self):
149
+ def to_dict(self, compare_mode=None):
161
150
  """
162
151
  用于数据输出
163
152
  """
@@ -166,7 +155,7 @@ class Graph:
166
155
  result[GraphConst.JSON_DATA_KEY] = self.data_path
167
156
  result[GraphConst.JSON_NODE_KEY] = {}
168
157
  for node_id in self.node_map:
169
- info = self.node_map.get(node_id).to_dict()
158
+ info = self.node_map.get(node_id).to_dict(compare_mode)
170
159
  result[GraphConst.JSON_NODE_KEY][node_id] = info
171
160
  return result
172
161
 
@@ -24,7 +24,6 @@ class NodeOp(Enum):
24
24
  function_api = 1
25
25
  api_collection = 9
26
26
 
27
-
28
27
  @staticmethod
29
28
  def get_node_op(node_name: str):
30
29
  """
@@ -37,5 +36,5 @@ class NodeOp(Enum):
37
36
  pattern = op_patterns[index]
38
37
  if re.match(pattern, node_name):
39
38
  return op
40
- logger.warning(f"Cannot parsing node_name {node_name} into NodeOp, default parsing as module.")
39
+ logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.")
41
40
  return NodeOp.module