mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,222 +1,223 @@
1
- import math
2
- import abc
3
- from collections import namedtuple
4
- import numpy as np
5
- import openpyxl
6
- from openpyxl.styles import PatternFill
7
- from msprobe.core.common.utils import get_header_index, save_workbook
8
- from msprobe.core.common.log import logger
9
- from msprobe.core.common.const import CompareConst
10
-
11
-
12
- class HighlightCheck(abc.ABC):
13
- @abc.abstractmethod
14
- def apply(self, info, color_columns, summary_compare):
15
- raise NotImplementedError
16
-
17
-
18
- class CheckOrderMagnitude(HighlightCheck):
19
- """检查Max diff的数量级差异"""
20
- def apply(self, info, color_columns, summary_compare=True):
21
- api_in, api_out, num = info
22
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
23
- if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
24
- return
25
- in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
26
- out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
27
- if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
28
- color_columns.yellow.append(num)
29
-
30
-
31
- class CheckOneThousandErrorRatio(HighlightCheck):
32
- """检查千分误差比率"""
33
- def apply(self, info, color_columns, summary_compare=True):
34
- api_in, api_out, num = info
35
- one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
36
- if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)):
37
- return
38
- if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
39
- color_columns.red.append(num)
40
- elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
41
- color_columns.yellow.append(num)
42
-
43
-
44
- class CheckCosineSimilarity(HighlightCheck):
45
- """检查余弦相似度"""
46
- def apply(self, info, color_columns, summary_compare=True):
47
- api_in, api_out, num = info
48
- cosine_index = get_header_index('Cosine', summary_compare)
49
- if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
50
- return
51
- if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
52
- color_columns.yellow.append(num)
53
-
54
-
55
- class CheckMaxRelativeDiff(HighlightCheck):
56
- """检查最大相对差异"""
57
- def apply(self, info, color_columns, summary_compare=True):
58
- api_in, api_out, num = info
59
- max_diff_index = get_header_index('Max diff', summary_compare)
60
- bench_max_index = get_header_index('Bench max', summary_compare)
61
- input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
62
- output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
63
- if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
64
- (float, int)):
65
- return
66
- if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
67
- color_columns.red.append(num)
68
- elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW:
69
- color_columns.yellow.append(num)
70
-
71
-
72
- class CheckOverflow(HighlightCheck):
73
- """检查是否存在溢出"""
74
- def apply(self, info, color_columns, summary_compare=True):
75
- line, num = info
76
- npu_max_index = get_header_index('NPU max', summary_compare)
77
- npu_min_index = get_header_index('NPU min', summary_compare)
78
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
79
- if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
80
- line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
81
- color_columns.red.append(num)
82
- return
83
- # check if Max_Diff > 1e+10
84
- if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
85
- color_columns.red.append(num)
86
-
87
-
88
- class HighlightRules:
89
- """高亮规则集合,用于检查API的误差"""
90
- # 适用于每行的规则
91
- basic_rules = {
92
- "check_overflow": CheckOverflow()
93
- }
94
-
95
- # 用于比较输入和输出的规则
96
- compare_rules = {
97
- "check_order_magnitude": CheckOrderMagnitude(),
98
- "check_one_thousand_error": CheckOneThousandErrorRatio(),
99
- "check_cosine_similarity": CheckCosineSimilarity()
100
- }
101
- summary_compare_rules = {
102
- "check_order_magnitude": CheckOrderMagnitude(),
103
- "check_max_relative_diff": CheckMaxRelativeDiff(),
104
- }
105
-
106
-
107
- def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
108
- """找到单个API中需要高亮的行"""
109
- if md5_compare:
110
- return
111
- npu_max_index = get_header_index('NPU max', summary_compare)
112
- bench_max_index = get_header_index('Bench max', summary_compare)
113
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
114
-
115
- red_lines, yellow_lines = [], []
116
- LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
117
- ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
118
- ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
119
- color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
120
-
121
- # 对单行API的输入或输出进行误差判断
122
- for i, line in enumerate(result):
123
- num = last_len + i
124
- line_info = LineInfo(line_data=line, num_pointer=num)
125
- for rule in HighlightRules.basic_rules.values():
126
- rule.apply(line_info, color_columns, summary_compare)
127
-
128
- # 对API的输出与输入比较,进行误差判断
129
- for n, api_out in enumerate(result[n_num_input:len(result)]):
130
- num = last_len + n_num_input + n
131
- if num in red_lines:
132
- continue
133
- if not isinstance(api_out[npu_max_index], (float, int)) \
134
- or not isinstance(api_out[bench_max_index], (float, int)) \
135
- or not isinstance(api_out[max_diff_index], (float, int)):
136
- continue
137
- for _, api_in in enumerate(result[0:n_num_input]):
138
- if not isinstance(api_in[npu_max_index], (float, int)) \
139
- or not isinstance(api_in[bench_max_index], (float, int)) \
140
- or not isinstance(api_in[max_diff_index], (float, int)):
141
- continue
142
-
143
- api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
144
- if summary_compare:
145
- for rule in HighlightRules.summary_compare_rules.values():
146
- rule.apply(api_info, color_columns, summary_compare)
147
- else:
148
- for rule in HighlightRules.compare_rules.values():
149
- rule.apply(api_info, color_columns, summary_compare)
150
-
151
- highlight_dict.get('red_rows', []).extend(list(set(red_lines)))
152
- highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
153
-
154
-
155
- def get_name_and_state(name):
156
- """Get api/module name and state"""
157
- if "input" in name:
158
- api_name = name.split("input")[0]
159
- state = "input"
160
- else:
161
- api_name = name.split("output")[0]
162
- state = "output"
163
- return api_name, state
164
-
165
-
166
- def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
167
- """将dataframe根据API分组,并找到有误差的算子用于高亮"""
168
- result = result_df.values
169
- start, input_num, output_num, end = 0, 0, 0, len(result_df)
170
- last_api_name, last_state = None, None
171
- num, last_len = 0, 0
172
- for res_i in result:
173
- api_name, state = get_name_and_state(res_i[0])
174
- if last_api_name:
175
- if api_name == last_api_name:
176
- if state == last_state:
177
- num += 1
178
- else:
179
- input_num = num
180
- num, last_state = 1, state
181
- else:
182
- output_num = num
183
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
184
- summary_compare, md5_compare)
185
- num, last_api_name, last_state = 1, api_name, state
186
- start += input_num + output_num
187
- input_num, output_num = 1, 0
188
- else:
189
- num, last_api_name, last_state = 1, api_name, state
190
- if state:
191
- if state == "input":
192
- input_num = num
193
- else:
194
- output_num = num
195
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare)
196
-
197
-
198
- def highlight_rows_xlsx(result_df, highlight_dict, file_path):
199
- """Write and highlight results in Excel"""
200
- logger.info('Compare result is %s' % file_path)
201
-
202
- wb = openpyxl.Workbook()
203
- ws = wb.active
204
-
205
- # write header
206
- for j, col_name in enumerate(result_df.columns, start=1):
207
- ws.cell(row=1, column=j, value=col_name)
208
-
209
- for i, row in enumerate(result_df.iterrows(), start=2):
210
- for j, value in enumerate(row[1], start=1):
211
- if not isinstance(value, (float, int)):
212
- value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
213
- ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
214
-
215
- if (i - 2) in highlight_dict['red_rows']:
216
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
217
- end_color=CompareConst.RED, fill_type="solid")
218
- elif (i - 2) in highlight_dict['yellow_rows']:
219
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
220
- end_color=CompareConst.YELLOW, fill_type="solid")
221
-
222
- save_workbook(wb, file_path)
1
+ import math
2
+ import abc
3
+ from collections import namedtuple
4
+ import numpy as np
5
+ import openpyxl
6
+ from openpyxl.styles import PatternFill
7
+ from msprobe.core.common.utils import get_header_index
8
+ from msprobe.core.common.file_utils import save_workbook
9
+ from msprobe.core.common.log import logger
10
+ from msprobe.core.common.const import CompareConst
11
+
12
+
13
+ class HighlightCheck(abc.ABC):
14
+ @abc.abstractmethod
15
+ def apply(self, info, color_columns, summary_compare):
16
+ raise NotImplementedError
17
+
18
+
19
+ class CheckOrderMagnitude(HighlightCheck):
20
+ """检查Max diff的数量级差异"""
21
+ def apply(self, info, color_columns, summary_compare=True):
22
+ api_in, api_out, num = info
23
+ max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
24
+ if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
25
+ return
26
+ in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
27
+ out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
28
+ if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
29
+ color_columns.yellow.append(num)
30
+
31
+
32
+ class CheckOneThousandErrorRatio(HighlightCheck):
33
+ """检查千分误差比率"""
34
+ def apply(self, info, color_columns, summary_compare=True):
35
+ api_in, api_out, num = info
36
+ one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
37
+ if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)):
38
+ return
39
+ if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
40
+ color_columns.red.append(num)
41
+ elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
42
+ color_columns.yellow.append(num)
43
+
44
+
45
+ class CheckCosineSimilarity(HighlightCheck):
46
+ """检查余弦相似度"""
47
+ def apply(self, info, color_columns, summary_compare=True):
48
+ api_in, api_out, num = info
49
+ cosine_index = get_header_index('Cosine', summary_compare)
50
+ if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
51
+ return
52
+ if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
53
+ color_columns.yellow.append(num)
54
+
55
+
56
+ class CheckMaxRelativeDiff(HighlightCheck):
57
+ """检查最大相对差异"""
58
+ def apply(self, info, color_columns, summary_compare=True):
59
+ api_in, api_out, num = info
60
+ max_diff_index = get_header_index('Max diff', summary_compare)
61
+ bench_max_index = get_header_index('Bench max', summary_compare)
62
+ input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
63
+ output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
64
+ if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
65
+ (float, int)):
66
+ return
67
+ if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
68
+ color_columns.red.append(num)
69
+ elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW:
70
+ color_columns.yellow.append(num)
71
+
72
+
73
+ class CheckOverflow(HighlightCheck):
74
+ """检查是否存在溢出"""
75
+ def apply(self, info, color_columns, summary_compare=True):
76
+ line, num = info
77
+ npu_max_index = get_header_index('NPU max', summary_compare)
78
+ npu_min_index = get_header_index('NPU min', summary_compare)
79
+ max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
80
+ if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
81
+ line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
82
+ color_columns.red.append(num)
83
+ return
84
+ # check if Max_Diff > 1e+10
85
+ if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
86
+ color_columns.red.append(num)
87
+
88
+
89
class HighlightRules:
    """Collection of highlight rules used to check the error of an API."""

    # Rules applied to every single row.
    basic_rules = dict(
        check_overflow=CheckOverflow(),
    )

    # Rules that compare an output row against an input row.
    compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_one_thousand_error=CheckOneThousandErrorRatio(),
        check_cosine_similarity=CheckCosineSimilarity(),
    )
    # Input/output rules used instead of ``compare_rules`` in summary-compare mode.
    summary_compare_rules = dict(
        check_order_magnitude=CheckOrderMagnitude(),
        check_max_relative_diff=CheckMaxRelativeDiff(),
    )
106
+
107
+
108
def _has_numeric_stats(row, indices):
    """Return True when every cell of ``row`` at ``indices`` is a float or an int."""
    return all(isinstance(row[idx], (float, int)) for idx in indices)


def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
    """Find the rows of a single API that need to be highlighted.

    Args:
        result: rows belonging to one API; the first ``n_num_input`` rows are
            treated as inputs and the remaining rows as outputs.
        last_len: offset added to the local row position to obtain the row
            number that is recorded.
        n_num_input: number of input rows at the start of ``result``.
        highlight_dict: mutated in place; row numbers are appended to its
            ``'red_rows'`` and ``'yellow_rows'`` lists (the lists are created
            when missing, so results are never silently dropped).
        summary_compare: selects the column layout and the rule set.
        md5_compare: when true nothing is checked and the function returns
            immediately.
    """
    if md5_compare:
        return

    # Columns that must be numeric before input/output rules can be applied.
    stat_indices = (
        get_header_index('NPU max', summary_compare),
        get_header_index('Bench max', summary_compare),
        get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare),
    )

    red_lines, yellow_lines = [], []
    LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
    ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
    ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
    # The rules append to these very lists, so ``red_lines`` grows while looping.
    color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)

    # Per-row checks on every input or output row of the API.
    for i, line in enumerate(result):
        line_info = LineInfo(line_data=line, num_pointer=last_len + i)
        for rule in HighlightRules.basic_rules.values():
            rule.apply(line_info, color_columns, summary_compare)

    # Compare every output row against every input row.
    if summary_compare:
        pair_rules = HighlightRules.summary_compare_rules
    else:
        pair_rules = HighlightRules.compare_rules
    api_inputs = result[:n_num_input]
    for n, api_out in enumerate(result[n_num_input:]):
        num = last_len + n_num_input + n
        # Rows already marked red need no further checks.
        if num in red_lines or not _has_numeric_stats(api_out, stat_indices):
            continue
        for api_in in api_inputs:
            if not _has_numeric_stats(api_in, stat_indices):
                continue
            api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
            for rule in pair_rules.values():
                rule.apply(api_info, color_columns, summary_compare)

    # De-duplicate; red takes precedence over yellow.
    red_set = set(red_lines)
    highlight_dict.setdefault('red_rows', []).extend(red_set)
    highlight_dict.setdefault('yellow_rows', []).extend(set(yellow_lines) - red_set)
154
+
155
+
156
def get_name_and_state(name):
    """Return ``(api_name, state)`` for a row name.

    ``state`` is ``"input"`` when the name contains ``"input"`` and
    ``"output"`` otherwise; ``api_name`` is everything before the first
    occurrence of that word (the whole name when the word is absent).
    """
    state = "input" if "input" in name else "output"
    api_name = name.partition(state)[0]
    return api_name, state
165
+
166
+
167
def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
    """Group the dataframe rows by API and collect the rows to highlight.

    Rows are walked in order; consecutive rows sharing the same API name form
    one group made of its input rows followed by its output rows. Each group
    is handed to ``find_error_rows`` together with its starting row offset.

    Args:
        result_df: object whose ``values`` attribute iterates the result rows;
            the first cell of each row is the name passed to
            ``get_name_and_state``.
        highlight_dict: forwarded to ``find_error_rows``, which records the
            highlighted row numbers in it.
        summary_compare: forwarded to ``find_error_rows``.
        md5_compare: forwarded to ``find_error_rows``.
    """
    result = result_df.values
    start, input_num, output_num = 0, 0, 0
    last_api_name, last_state = None, None
    # Pre-bind ``state`` so an empty result skips the final flush below instead
    # of failing on an unbound local variable.
    state = None
    num = 0
    for res_i in result:
        api_name, state = get_name_and_state(res_i[0])
        if last_api_name:
            if api_name == last_api_name:
                if state == last_state:
                    num += 1
                else:
                    # Switched from inputs to outputs within the same API.
                    input_num = num
                    num, last_state = 1, state
            else:
                # A new API starts: flush the group that just ended.
                output_num = num
                find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
                                summary_compare, md5_compare)
                num, last_api_name, last_state = 1, api_name, state
                start += input_num + output_num
                input_num, output_num = 1, 0
        else:
            # First row: open the first group.
            num, last_api_name, last_state = 1, api_name, state

    # Flush the last (still open) group.
    if state:
        if state == "input":
            input_num = num
        else:
            output_num = num
        find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
                        summary_compare, md5_compare)
197
+
198
+
199
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
    """Write the compare result to an Excel workbook and colour flagged rows.

    Args:
        result_df: dataframe to write; its columns become the header row and
            each dataframe row becomes one worksheet row.
        highlight_dict: ``'red_rows'`` / ``'yellow_rows'`` hold zero-based row
            positions to fill red / yellow; red wins when a row is in both.
            A missing key is treated as "no rows".
        file_path: destination passed to ``save_workbook``.
    """
    logger.info('Compare result is %s' % file_path)

    wb = openpyxl.Workbook()
    ws = wb.active

    # write header
    for j, col_name in enumerate(result_df.columns, start=1):
        ws.cell(row=1, column=j, value=col_name)

    # Build the lookup sets and fills once instead of once per cell.
    red_rows = set(highlight_dict.get('red_rows', ()))
    yellow_rows = set(highlight_dict.get('yellow_rows', ()))
    red_fill = PatternFill(start_color=CompareConst.RED,
                           end_color=CompareConst.RED, fill_type="solid")
    yellow_fill = PatternFill(start_color=CompareConst.YELLOW,
                              end_color=CompareConst.YELLOW, fill_type="solid")
    special_texts = ('inf', '-inf', 'nan')

    # Worksheet rows are 1-based and row 1 is the header, hence ``start=2``.
    for i, (_, row) in enumerate(result_df.iterrows(), start=2):
        row_pos = i - 2
        if row_pos in red_rows:
            fill = red_fill
        elif row_pos in yellow_rows:
            fill = yellow_fill
        else:
            fill = None

        for j, value in enumerate(row, start=1):
            text = str(value)
            if text in special_texts:
                # inf/-inf/nan are written as text with a trailing tab.
                cell_value = f'{text}\t'
            elif isinstance(value, (float, int)):
                cell_value = value
            else:
                cell_value = text
            cell = ws.cell(row=i, column=j, value=cell_value)
            if fill is not None:
                cell.fill = fill

    save_workbook(wb, file_path)