mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -20,89 +20,108 @@ from collections import namedtuple
20
20
  import numpy as np
21
21
  import openpyxl
22
22
  from openpyxl.styles import PatternFill
23
+ from tqdm import tqdm
23
24
  from msprobe.core.common.utils import get_header_index
24
25
  from msprobe.core.common.file_utils import save_workbook
25
26
  from msprobe.core.common.log import logger
26
- from msprobe.core.common.const import CompareConst, FileCheckConst
27
+ from msprobe.core.common.const import CompareConst, FileCheckConst, Const
28
+ from msprobe.core.common.utils import safe_get_value
27
29
 
28
30
 
29
31
  class HighlightCheck(abc.ABC):
30
32
  @abc.abstractmethod
31
- def apply(self, info, color_columns, summary_compare):
33
+ def apply(self, info, color_columns, dump_mode):
32
34
  raise NotImplementedError
33
35
 
34
36
 
37
+ def add_highlight_row_info(color_list, num, highlight_err_msg):
38
+ for i, (existing_num, existing_err_msg) in enumerate(color_list):
39
+ if num == existing_num:
40
+ color_list[i][1].append(highlight_err_msg)
41
+ return
42
+ color_list.append((num, [highlight_err_msg]))
43
+
44
+
35
45
  class CheckOrderMagnitude(HighlightCheck):
36
46
  """检查Max diff的数量级差异"""
37
- def apply(self, info, color_columns, summary_compare=True):
47
+ def apply(self, info, color_columns, dump_mode):
38
48
  api_in, api_out, num = info
39
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
49
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
50
+ else CompareConst.MAX_ABS_ERR, dump_mode)
40
51
  if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
41
52
  return
42
53
  in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
43
54
  out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
44
55
  if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
45
- color_columns.yellow.append(num)
56
+ add_highlight_row_info(color_columns.yellow, num,
57
+ "maximum absolute error of both input and output exceed 1, "
58
+ "with the output larger by an order of magnitude")
46
59
 
47
60
 
48
61
  class CheckOneThousandErrorRatio(HighlightCheck):
49
62
  """检查千分误差比率"""
50
- def apply(self, info, color_columns, summary_compare=True):
63
+ def apply(self, info, color_columns, dump_mode):
51
64
  api_in, api_out, num = info
52
- one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
65
+ one_thousand_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
53
66
  if (not isinstance(api_in[one_thousand_index], (float, int)) or
54
67
  not isinstance(api_out[one_thousand_index], (float, int))):
55
68
  return
56
69
  if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
57
70
  api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
58
- color_columns.red.append(num)
71
+ add_highlight_row_info(color_columns.red, num,
72
+ "The input's one thousandth err ratio exceeds 0.9, while the output's is below 0.6")
59
73
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
60
- color_columns.yellow.append(num)
74
+ add_highlight_row_info(color_columns.yellow, num,
75
+ "The output's one thousandth err ratio decreases by more than 0.1 "
76
+ "compared to the input's")
61
77
 
62
78
 
63
79
  class CheckCosineSimilarity(HighlightCheck):
64
80
  """检查余弦相似度"""
65
- def apply(self, info, color_columns, summary_compare=True):
81
+ def apply(self, info, color_columns, dump_mode):
66
82
  api_in, api_out, num = info
67
- cosine_index = get_header_index('Cosine', summary_compare)
83
+ cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
68
84
  if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
69
85
  return
70
86
  if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
71
- color_columns.yellow.append(num)
87
+ add_highlight_row_info(color_columns.yellow, num,
88
+ "The output's cosine decreases by more than 0.1 compared to the input's")
72
89
 
73
90
 
74
91
  class CheckMaxRelativeDiff(HighlightCheck):
75
92
  """检查最大相对差异"""
76
- def apply(self, info, color_columns, summary_compare=True):
93
+ def apply(self, info, color_columns, dump_mode):
77
94
  api_in, api_out, num = info
78
- max_diff_index = get_header_index('Max diff', summary_compare)
79
- bench_max_index = get_header_index('Bench max', summary_compare)
95
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
96
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
80
97
  input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
81
98
  output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
82
99
  if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
83
100
  (float, int)):
84
101
  return
85
102
  if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
86
- color_columns.red.append(num)
103
+ add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
87
104
  elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
88
105
  input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
89
- color_columns.yellow.append(num)
106
+ add_highlight_row_info(color_columns.yellow, num,
107
+ "The output's maximum relative error exceeds 0.1, while the input's is below 0.01")
90
108
 
91
109
 
92
110
  class CheckOverflow(HighlightCheck):
93
111
  """检查是否存在溢出"""
94
- def apply(self, info, color_columns, summary_compare=True):
112
+ def apply(self, info, color_columns, dump_mode):
95
113
  line, num = info
96
- npu_max_index = get_header_index('NPU max', summary_compare)
97
- npu_min_index = get_header_index('NPU min', summary_compare)
98
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
114
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
115
+ npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
116
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
117
+ else CompareConst.MAX_ABS_ERR, dump_mode)
99
118
  if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
100
119
  line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
101
- color_columns.red.append(num)
120
+ add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
102
121
  return
103
122
  # check if Max_Diff > 1e+10
104
- if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
105
- color_columns.red.append(num)
123
+ if isinstance(line[max_diff_index], (float, int)) and abs(line[max_diff_index]) > CompareConst.MAX_DIFF_RED:
124
+ add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
106
125
 
107
126
 
108
127
  class HighlightRules:
@@ -124,13 +143,14 @@ class HighlightRules:
124
143
  }
125
144
 
126
145
 
127
- def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
146
+ def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
128
147
  """找到单个API中需要高亮的行"""
129
- if md5_compare:
148
+ if dump_mode == Const.MD5:
130
149
  return
131
- npu_max_index = get_header_index('NPU max', summary_compare)
132
- bench_max_index = get_header_index('Bench max', summary_compare)
133
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
150
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
151
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
152
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
153
+ else CompareConst.MAX_ABS_ERR, dump_mode)
134
154
 
135
155
  red_lines, yellow_lines = [], []
136
156
  LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
@@ -143,7 +163,7 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
143
163
  num = last_len + i
144
164
  line_info = LineInfo(line_data=line, num_pointer=num)
145
165
  for rule in HighlightRules.basic_rules.values():
146
- rule.apply(line_info, color_columns, summary_compare)
166
+ rule.apply(line_info, color_columns, dump_mode)
147
167
 
148
168
  # 对API的输出与输入比较,进行误差判断
149
169
  for n, api_out in enumerate(result[n_num_input:len(result)]):
@@ -161,36 +181,42 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
161
181
  continue
162
182
 
163
183
  api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
164
- if summary_compare:
184
+ if dump_mode == Const.SUMMARY:
165
185
  for rule in HighlightRules.summary_compare_rules.values():
166
- rule.apply(api_info, color_columns, summary_compare)
186
+ rule.apply(api_info, color_columns, dump_mode)
167
187
  else:
168
188
  for rule in HighlightRules.compare_rules.values():
169
- rule.apply(api_info, color_columns, summary_compare)
189
+ rule.apply(api_info, color_columns, dump_mode)
170
190
 
171
- highlight_dict.get('red_rows', []).extend(list(set(red_lines)))
172
- highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
191
+ red_lines_num_set = {x[0] for x in red_lines}
192
+ yellow_lines_num_set = {x[0] for x in yellow_lines}
193
+ highlight_dict.get('red_rows', set()).update(red_lines_num_set)
194
+ highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
195
+ highlight_dict.get('red_lines', []).extend(red_lines)
196
+ highlight_dict.get('yellow_lines', []).extend(yellow_lines)
173
197
 
174
198
 
175
199
  def get_name_and_state(name):
176
200
  """Get api/module name and state"""
177
- if "input" in name:
178
- api_name = name.split("input")[0]
179
- state = "input"
201
+ if Const.INPUT in name:
202
+ api_name = name.split(Const.INPUT)[0]
203
+ state = Const.INPUT
180
204
  else:
181
- api_name = name.split("output")[0]
182
- state = "output"
205
+ api_name = name.split(Const.OUTPUT)[0]
206
+ state = Const.OUTPUT
183
207
  return api_name, state
184
208
 
185
209
 
186
- def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
210
+ def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
187
211
  """将dataframe根据API分组,并找到有误差的算子用于高亮"""
188
212
  result = result_df.values
189
213
  start, input_num, output_num, end = 0, 0, 0, len(result_df)
190
214
  last_api_name, last_state = None, None
191
215
  num, last_len = 0, 0
216
+ progress_bar = tqdm(total=len(result), desc="API/Module Analyse Progress", unit="item", ncols=100)
192
217
  for res_i in result:
193
- api_name, state = get_name_and_state(res_i[0])
218
+ api_full_name = safe_get_value(res_i, 0, "res_i")
219
+ api_name, state = get_name_and_state(api_full_name)
194
220
  if last_api_name:
195
221
  if api_name == last_api_name:
196
222
  if state == last_state:
@@ -201,29 +227,33 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
201
227
  else:
202
228
  output_num = num
203
229
  find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
204
- summary_compare, md5_compare)
230
+ dump_mode)
205
231
  num, last_api_name, last_state = 1, api_name, state
206
232
  start += input_num + output_num
207
233
  input_num, output_num = 1, 0
208
234
  else:
209
235
  num, last_api_name, last_state = 1, api_name, state
236
+ progress_bar.update(1)
237
+ progress_bar.close()
210
238
  if state:
211
- if state == "input":
239
+ if state == Const.INPUT:
212
240
  input_num = num
213
241
  else:
214
242
  output_num = num
215
243
  find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
216
- summary_compare, md5_compare)
244
+ dump_mode)
217
245
 
218
246
 
219
247
  def highlight_rows_xlsx(result_df, highlight_dict, file_path):
220
248
  """Write and highlight results in Excel"""
221
- logger.info('Compare result is %s' % file_path)
249
+
250
+ update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
222
251
 
223
252
  wb = openpyxl.Workbook()
224
253
  ws = wb.active
225
254
 
226
255
  # write header
256
+ logger.info('Initializing Excel file.')
227
257
  for j, col_name in enumerate(result_df.columns, start=1):
228
258
  if not csv_value_is_valid(col_name):
229
259
  raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
@@ -231,20 +261,59 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
231
261
 
232
262
  for i, row in enumerate(result_df.iterrows(), start=2):
233
263
  for j, value in enumerate(row[1], start=1):
234
- if not isinstance(value, (float, int)):
264
+ if not isinstance(value, (float, int)) or isinstance(value, bool):
235
265
  value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
236
266
  if not csv_value_is_valid(value):
237
- raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: {file_path}.")
267
+ raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: "
268
+ f"{file_path}.")
238
269
  ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
270
+
271
+ # 对可疑数据标色
272
+ logger.info('Coloring Excel in progress.')
273
+ col_len = len(result_df.columns)
274
+ red_fill = PatternFill(
275
+ start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
276
+ )
277
+ yellow_fill = PatternFill(
278
+ start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
279
+ )
280
+ for i in highlight_dict.get("red_rows", []):
281
+ for j in range(1, col_len + 1):
282
+ ws.cell(row=i + 2, column=j).fill = red_fill
283
+ for i in highlight_dict.get("yellow_rows", []):
284
+ for j in range(1, col_len + 1):
285
+ ws.cell(row=i + 2, column=j).fill = yellow_fill
286
+ logger.info('Saving Excel file to disk: %s' % file_path)
287
+ save_workbook(wb, file_path)
239
288
 
240
- if (i - 2) in highlight_dict['red_rows']:
241
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
242
- end_color=CompareConst.RED, fill_type="solid")
243
- elif (i - 2) in highlight_dict['yellow_rows']:
244
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
245
- end_color=CompareConst.YELLOW, fill_type="solid")
246
289
 
247
- save_workbook(wb, file_path)
290
+ def update_highlight_err_msg(result_df, highlight_dict):
291
+ if result_df.shape[1] <= 1:
292
+ return
293
+
294
+ if CompareConst.NPU_MD5 in result_df.columns:
295
+ return
296
+
297
+ err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
298
+ red_lines_num_set = highlight_dict.get('red_rows')
299
+
300
+ for color in ['red', 'yellow']:
301
+ line_key = f'{color}_lines'
302
+ lines = highlight_dict.get(line_key, [])
303
+ for line_index, messages in lines:
304
+ if color == 'yellow' and line_index in red_lines_num_set:
305
+ continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
306
+
307
+ for msg in messages:
308
+ if err_msg[line_index] == '':
309
+ err_msg[line_index] = msg
310
+ else:
311
+ err_msg[line_index] += '\n' + msg
312
+
313
+ if color == 'red':
314
+ red_lines_num_set.add(line_index)
315
+
316
+ result_df[CompareConst.ERROR_MESSAGE] = err_msg
248
317
 
249
318
 
250
319
  def csv_value_is_valid(value: str) -> bool:
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.compare.layer_mapping.layer_mapping import (
17
+ generate_data_mapping_by_layer_mapping,
18
+ generate_api_mapping_by_layer_mapping,
19
+ )
@@ -0,0 +1,235 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ from copy import deepcopy
19
+ from dataclasses import dataclass
20
+ from typing import ClassVar, Dict, List, Optional, Tuple
21
+
22
+ import yaml
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import save_yaml
25
+ from msprobe.core.common.log import logger
26
+ from msprobe.core.common.utils import CompareException, add_time_with_yaml
27
+ from msprobe.core.compare.layer_mapping.postprocess_pass import postprocess_pass
28
+
29
+
30
@dataclass
class DumpDataItem:
    """One dumped data record together with the scope strings derived from it.

    An instance is created with only ``framework`` set; ``set`` then fills the
    remaining fields from the record name, its construct (parent) entry and
    its stack entry.
    """

    framework: str
    data_name: Optional[str] = None
    api_type: Optional[str] = None
    api_name: Optional[str] = None
    type_name: Optional[str] = None
    full_scope: str = ""
    layer_scope: str = ""
    stack_scope: str = ""
    frame_stack_scope: str = ""
    user_stack_scope: str = ""
    construct_scope: str = ""
    scope_direction: Optional[str] = None
    # NOTE(review): set_layer_scope stores the raw string field taken from the
    # construct name here, not an int - confirm the intended type.
    scope_id: Optional[int] = None

    # Class-level lookup tables; ClassVar keeps them out of the dataclass fields.
    # Framework -> api_type value that marks a layer record (Cell / Module).
    framework2layername: ClassVar[Dict[str, str]] = {
        Const.MS_FRAMEWORK: Const.CELL, Const.PT_FRAMEWORK: Const.MODULE}
    # Framework -> (start, end) regex markers delimiting the useful stack lines.
    framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
        Const.MS_FRAMEWORK: ("Template", "construct"),
        Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
    }

    @staticmethod
    def check_stack_valid(stack_info):
        """Validate that ``stack_info`` is ``None`` or a list of strings.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when ``stack_info`` is
                neither ``None`` nor a list whose items are all ``str``.
        """
        if stack_info is None:
            return
        if not isinstance(stack_info, list) or not all(isinstance(stack, str) for stack in stack_info):
            logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
            raise CompareException(CompareException.INVALID_DATA_ERROR)

    def set(self, data_name: str, construct_info: Optional[str], stack_info: Optional[List[str]]) -> None:
        """Populate every derived field of this item.

        Args:
            data_name: name of the dumped record, fields joined by ``Const.SEP``.
            construct_info: name of the parent scope, or ``None``/empty when
                there is none.
            stack_info: stack lines recorded for this record, or ``None``/empty.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when any input is malformed.
        """
        # Order matters: the scope setters read api_type/data_name set by
        # set_name, and set_full_scope reads what the scope setters produce.
        self.set_name(data_name)
        self.set_layer_scope(construct_info)
        self.set_stack_scope(stack_info)
        self.set_full_scope()

    def set_name(self, data_name: str) -> None:
        """Split ``data_name`` and fill ``api_type``, ``type_name`` and ``api_name``.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when the name has fewer
                fields than ``abs(Const.LAYER_NAME_INDEX)``.
        """
        self.data_name = data_name
        data_name_list = data_name.split(Const.SEP)
        # str.split never returns an empty list, so the length check is enough.
        if len(data_name_list) < abs(Const.LAYER_NAME_INDEX):
            logger.error(
                f"The dump data does not comply with the format specification and "
                f"must contain no less than four fields. "
                f"The current data is {data_name}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)

        self.api_type = data_name_list[Const.API_TYPE_INDEX]
        self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
        if self.api_type == self.framework2layername.get(self.framework):
            # Layer records carry their own name in a dedicated field.
            self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
        else:
            self.api_name = self.type_name

    def set_layer_scope(self, construct_info: Optional[str]) -> None:
        """Derive ``layer_scope``, ``scope_id`` and ``scope_direction``.

        Layer records build the scope from their own name; other records use
        the construct name; with neither available, the framework's layer name
        is used as the scope.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when a non-empty
                ``construct_info`` has fewer fields than
                ``abs(Const.LAYER_NAME_INDEX)``.
        """
        self.construct_scope = construct_info
        layer_name = self.framework2layername.get(self.framework)
        if self.api_type == layer_name:
            # Remove the api name from the record's own name.
            data_list = self.data_name.split(Const.SEP)
            data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
        elif construct_info:
            data_list = construct_info.split(Const.SEP)
        else:
            data_list = []

        if data_list:
            self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
        else:
            self.layer_scope = layer_name

        if construct_info:
            construct_list = construct_info.split(Const.SEP)
            if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                logger.error(
                    f"The construct data does not comply with the format specification and "
                    f"must contain no less than four fields. "
                    f"The current data is {construct_info}"
                )
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            self.scope_id = construct_list[Const.SCOPE_ID_INDEX]
            self.scope_direction = construct_list[Const.SCOPE_DIRECTION_INDEX]

    def set_stack_scope(self, stack_info: Optional[List[str]]) -> None:
        """Derive ``frame_stack_scope`` and ``user_stack_scope`` from the stack.

        Nothing is set for layer records, for api types listed in
        ``Const.DATA_TYPE_SKIP_LIST`` or when ``stack_info`` is empty.

        Raises:
            CompareException: ``INVALID_DATA_ERROR`` when ``stack_info`` is not
                a list of strings, or when no stack markers are defined for
                ``self.framework``.
        """
        # Cell/Module has no stack info
        if self.api_type == self.framework2layername.get(self.framework):
            return

        if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
            return

        stack_sign = self.framework2stack_sign.get(self.framework)
        if stack_sign is None:
            # Unpacking None would otherwise fail with an opaque TypeError.
            logger.error(f"framework is invalid, no stack markers are defined for {self.framework}")
            raise CompareException(CompareException.INVALID_DATA_ERROR)
        start_sign, end_sign = stack_sign

        self.check_stack_valid(stack_info)
        start_pos, end_pos = find_regard_scope(stack_info, start_sign, end_sign)
        # Keep only the stack lines strictly between the two marker lines.
        regard_scope = stack_info[start_pos + 1:end_pos]
        frame_func_stack_list, user_func_stack_list = find_stack_func_list(regard_scope)
        self.frame_stack_scope = Const.SEP.join(frame_func_stack_list)
        self.user_stack_scope = Const.SEP.join(user_func_stack_list)

    def set_full_scope(self, use_user_func_scope=False, use_frame_func_scope=True) -> None:
        """Join the layer scope, the selected stack scopes and the api name.

        Args:
            use_user_func_scope: include ``user_stack_scope`` when it is non-empty.
            use_frame_func_scope: include ``frame_stack_scope`` when it is non-empty.
        """
        scope_list = [self.layer_scope]
        if use_user_func_scope and self.user_stack_scope:
            scope_list.append(self.user_stack_scope)
        if use_frame_func_scope and self.frame_stack_scope:
            scope_list.append(self.frame_stack_scope)
        scope_list.append(self.api_name)
        self.full_scope = Const.SEP.join(scope_list)
141
+
142
+
143
def find_regard_scope(lines, start_sign, end_sign):
    """Find the marker positions that delimit the relevant part of ``lines``.

    Args:
        lines: sequence of text lines to scan in order.
        start_sign: regex; every matching line moves the start position forward.
        end_sign: regex; the first matching line after a start match ends the scan.

    Returns:
        ``(start_pos, end_pos)``. ``start_pos`` is ``-1`` when no line matches
        ``start_sign``; ``end_pos`` is ``len(lines)`` when no end line is found.
    """
    start_pos, end_pos = -1, len(lines)
    for position, text in enumerate(lines):
        if re.search(start_sign, text):
            # A line matching the start marker is never treated as an end line.
            start_pos = position
            continue
        if start_pos < 0:
            # The end marker only counts once a start marker has been seen.
            continue
        if re.search(end_sign, text):
            end_pos = position
            break
    return start_pos, end_pos
154
+
155
+
156
def find_stack_func_list(lines, record_user=True):
    """Split stack lines into the framework entry line and the user lines.

    Lines are scanned in order. While only framework lines have been seen
    (file field containing an entry of ``Const.FRAME_FILE_LIST``), the most
    recent one is remembered as the framework line. From the first
    non-framework line onwards every line, framework or not, is treated as a
    user line.

    Args:
        lines: stack lines, each a ``Const.COMMA``-separated record.
        record_user: when false, user lines are not collected.

    Returns:
        ``(frame_func, user_func)``: the function names extracted by
        ``get_stack_in_lines`` from the framework line and from the user lines.
    """
    user_stack = []
    frame_stack = None
    no_entrance = True
    for line in lines:
        file_ele = line.split(Const.COMMA)[Const.STACK_FILE_INDEX]
        is_frame_line = any(ii in file_ele for ii in Const.FRAME_FILE_LIST)
        if is_frame_line and no_entrance:
            # Still inside the leading framework lines: keep the latest one.
            frame_stack = line
        else:
            if record_user:
                user_stack.append(line)
            no_entrance = False

    # frame_stack may still be None here; get_stack_in_lines skips empty entries.
    frame_func = get_stack_in_lines([frame_stack])
    user_func = get_stack_in_lines(user_stack)
    return (frame_func, user_func)
180
+
181
+
182
def get_stack_in_lines(simplified: List[str]):
    """Extract function names from stack lines, returned in reverse line order.

    Empty entries are ignored, as are lines whose file field contains an entry
    of ``Const.FILE_SKIP_LIST`` or whose function field contains an entry of
    ``Const.FUNC_SKIP_LIST``.

    Args:
        simplified: stack lines, each a ``Const.COMMA``-separated record; a
            falsy value yields an empty list.

    Returns:
        The extracted function names, last input line first.
    """
    func_names = []
    for entry in simplified or ():
        if not entry:
            continue

        parts = entry.split(Const.COMMA)
        file_part = parts[Const.STACK_FILE_INDEX]
        if any(skip in file_part for skip in Const.FILE_SKIP_LIST):
            continue

        func_part = parts[Const.STACK_FUNC_INDEX]
        if any(skip in func_part for skip in Const.FUNC_SKIP_LIST):
            continue

        # The function field is whitespace-separated; pick the name token.
        func_names.append(func_part.split()[Const.STACK_FUNC_ELE_INDEX])

    func_names.reverse()
    return func_names
205
+
206
+
207
def dumpdata_representer(dumper, data):
    """Represent ``data`` as a mapping of its attributes without ``data_name``.

    Args:
        dumper: object providing ``represent_dict``.
        data: object whose ``__dict__`` holds the attributes to represent; it
            must contain a ``data_name`` entry.

    Returns:
        Whatever ``dumper.represent_dict`` returns for the filtered mapping.
    """
    # Work on a deep copy so the instance itself keeps its data_name.
    attributes = deepcopy(vars(data))
    del attributes["data_name"]
    return dumper.represent_dict(attributes)
211
+
212
+
213
def get_dump_data_items(dump, stack, construct, framework, output_path=None):
    """Build a ``DumpDataItem`` for every record under ``dump["data"]``.

    Args:
        dump: mapping whose ``"data"`` entry is keyed by record name.
        stack: mapping from record name to its stack lines.
        construct: mapping from record name to its parent scope name.
        framework: framework identifier passed to each ``DumpDataItem``.
        output_path: optional directory; when given, the name-to-item mapping
            is also saved there as a time-stamped YAML file.

    Returns:
        The created items in iteration order of ``dump["data"]``, or an empty
        list when ``stack`` or ``construct`` is empty.
    """
    if not (stack and construct):
        return []

    name2item = {}
    data_items = []
    for data_name in dump.get("data", {}):
        code_info = stack.get(data_name, None)
        parent_info = construct.get(data_name, None)
        item = DumpDataItem(framework)
        item.set(data_name, parent_info, code_info)
        name2item[data_name] = item
        data_items.append(item)

    postprocess_pass(data_items, name2item)

    if output_path:
        # Register the representer so DumpDataItem values can be serialised.
        yaml.add_representer(DumpDataItem, dumpdata_representer)
        target_dir = os.path.realpath(output_path)
        file_path = os.path.join(target_dir, add_time_with_yaml(f"{framework}_data"))
        save_yaml(file_path, name2item)
    return data_items