mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,96 +13,129 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import math
17
16
  import abc
17
+ import math
18
+ import multiprocessing
18
19
  import re
19
20
  from collections import namedtuple
21
+
20
22
  import numpy as np
21
23
  import openpyxl
22
24
  from openpyxl.styles import PatternFill
23
- from msprobe.core.common.utils import get_header_index
25
+ from openpyxl.utils.dataframe import dataframe_to_rows
26
+ from tqdm import tqdm
27
+
28
+ from msprobe.core.common.const import CompareConst, Const
24
29
  from msprobe.core.common.file_utils import save_workbook
25
30
  from msprobe.core.common.log import logger
26
- from msprobe.core.common.const import CompareConst, FileCheckConst
31
+ from msprobe.core.common.utils import get_header_index, safe_get_value
32
+ from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
27
33
 
28
34
 
29
35
  class HighlightCheck(abc.ABC):
30
36
  @abc.abstractmethod
31
- def apply(self, info, color_columns, summary_compare):
37
+ def apply(self, info, color_columns, dump_mode):
32
38
  raise NotImplementedError
33
39
 
34
40
 
41
+ def add_highlight_row_info(color_list, num, highlight_err_msg):
42
+ for i, (existing_num, existing_err_msg) in enumerate(color_list):
43
+ if num == existing_num:
44
+ color_list[i][1].append(highlight_err_msg)
45
+ return
46
+ color_list.append((num, [highlight_err_msg]))
47
+
48
+
35
49
  class CheckOrderMagnitude(HighlightCheck):
36
50
  """检查Max diff的数量级差异"""
37
- def apply(self, info, color_columns, summary_compare=True):
51
+
52
+ def apply(self, info, color_columns, dump_mode):
38
53
  api_in, api_out, num = info
39
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
54
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
55
+ else CompareConst.MAX_ABS_ERR, dump_mode)
40
56
  if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
41
57
  return
42
58
  in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
43
59
  out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
44
60
  if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
45
- color_columns.yellow.append(num)
61
+ add_highlight_row_info(color_columns.yellow, num,
62
+ "maximum absolute error of both input/parameters and output exceed 1, "
63
+ "with the output larger by an order of magnitude")
46
64
 
47
65
 
48
66
  class CheckOneThousandErrorRatio(HighlightCheck):
49
67
  """检查千分误差比率"""
50
- def apply(self, info, color_columns, summary_compare=True):
68
+
69
+ def apply(self, info, color_columns, dump_mode):
51
70
  api_in, api_out, num = info
52
- one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
71
+ one_thousand_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
53
72
  if (not isinstance(api_in[one_thousand_index], (float, int)) or
54
73
  not isinstance(api_out[one_thousand_index], (float, int))):
55
74
  return
56
75
  if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
57
76
  api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
58
- color_columns.red.append(num)
77
+ add_highlight_row_info(color_columns.red, num,
78
+ "The input/parameters's one thousandth err ratio exceeds 0.9, "
79
+ "while the output's is below 0.6")
59
80
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
60
- color_columns.yellow.append(num)
81
+ add_highlight_row_info(color_columns.yellow, num,
82
+ "The output's one thousandth err ratio decreases by more than 0.1 "
83
+ "compared to the input/parameters's")
61
84
 
62
85
 
63
86
  class CheckCosineSimilarity(HighlightCheck):
64
87
  """检查余弦相似度"""
65
- def apply(self, info, color_columns, summary_compare=True):
88
+
89
+ def apply(self, info, color_columns, dump_mode):
66
90
  api_in, api_out, num = info
67
- cosine_index = get_header_index('Cosine', summary_compare)
91
+ cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
68
92
  if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
69
93
  return
70
94
  if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
71
- color_columns.yellow.append(num)
95
+ add_highlight_row_info(color_columns.yellow, num,
96
+ "The output's cosine decreases by more than 0.1 "
97
+ "compared to the input/parameters's")
72
98
 
73
99
 
74
100
  class CheckMaxRelativeDiff(HighlightCheck):
75
101
  """检查最大相对差异"""
76
- def apply(self, info, color_columns, summary_compare=True):
102
+
103
+ def apply(self, info, color_columns, dump_mode):
77
104
  api_in, api_out, num = info
78
- max_diff_index = get_header_index('Max diff', summary_compare)
79
- bench_max_index = get_header_index('Bench max', summary_compare)
80
- input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
81
- output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
105
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
106
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
107
+ input_max_relative_diff = np.abs(
108
+ np.divide(api_in[max_diff_index], max(Const.FLOAT_EPSILON, api_in[bench_max_index])))
109
+ output_max_relative_diff = np.abs(
110
+ np.divide(api_out[max_diff_index], max(Const.FLOAT_EPSILON, api_out[bench_max_index])))
82
111
  if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
83
112
  (float, int)):
84
113
  return
85
114
  if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
86
- color_columns.red.append(num)
115
+ add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
87
116
  elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
88
117
  input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
89
- color_columns.yellow.append(num)
118
+ add_highlight_row_info(color_columns.yellow, num,
119
+ "The output's maximum relative error exceeds 0.1, "
120
+ "while the input/parameters's is below 0.01")
90
121
 
91
122
 
92
123
  class CheckOverflow(HighlightCheck):
93
124
  """检查是否存在溢出"""
94
- def apply(self, info, color_columns, summary_compare=True):
125
+
126
+ def apply(self, info, color_columns, dump_mode):
95
127
  line, num = info
96
- npu_max_index = get_header_index('NPU max', summary_compare)
97
- npu_min_index = get_header_index('NPU min', summary_compare)
98
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
128
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
129
+ npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
130
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
131
+ else CompareConst.MAX_ABS_ERR, dump_mode)
99
132
  if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
100
133
  line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
101
- color_columns.red.append(num)
134
+ add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
102
135
  return
103
136
  # check if Max_Diff > 1e+10
104
- if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
105
- color_columns.red.append(num)
137
+ if isinstance(line[max_diff_index], (float, int)) and abs(line[max_diff_index]) > CompareConst.MAX_DIFF_RED:
138
+ add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
106
139
 
107
140
 
108
141
  class HighlightRules:
@@ -122,15 +155,31 @@ class HighlightRules:
122
155
  "check_order_magnitude": CheckOrderMagnitude(),
123
156
  "check_max_relative_diff": CheckMaxRelativeDiff(),
124
157
  }
125
-
126
158
 
127
- def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
159
+
160
+ def check_indices_numeric(api_items, indices: list):
161
+ """检查指定索引处的值是否都为数字类型(int 或 float)"""
162
+ return all(isinstance(api_items[i], (float, int)) for i in indices)
163
+
164
+
165
+ def apply_comparison_rules(api_info, dump_mode, color_columns):
166
+ """output与input/params的比较"""
167
+ if dump_mode == Const.SUMMARY:
168
+ for rule in HighlightRules.summary_compare_rules.values():
169
+ rule.apply(api_info, color_columns, dump_mode)
170
+ else:
171
+ for rule in HighlightRules.compare_rules.values():
172
+ rule.apply(api_info, color_columns, dump_mode)
173
+
174
+
175
+ def find_error_rows(result, api_batch, highlight_dict, dump_mode):
128
176
  """找到单个API中需要高亮的行"""
129
- if md5_compare:
177
+ if dump_mode == Const.MD5:
130
178
  return
131
- npu_max_index = get_header_index('NPU max', summary_compare)
132
- bench_max_index = get_header_index('Bench max', summary_compare)
133
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
179
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
180
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
181
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
182
+ else CompareConst.MAX_ABS_ERR, dump_mode)
134
183
 
135
184
  red_lines, yellow_lines = [], []
136
185
  LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
@@ -138,122 +187,229 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
138
187
  ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
139
188
  color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
140
189
 
190
+ api_batch_start = api_batch.start # result_df的input起始全局索引
191
+ api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
192
+ api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
193
+ api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
194
+ api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
195
+
141
196
  # 对单行API的输入或输出进行误差判断
142
197
  for i, line in enumerate(result):
143
- num = last_len + i
144
- line_info = LineInfo(line_data=line, num_pointer=num)
198
+ index = api_batch_start + i
199
+ line_info = LineInfo(line_data=line, num_pointer=index)
145
200
  for rule in HighlightRules.basic_rules.values():
146
- rule.apply(line_info, color_columns, summary_compare)
201
+ rule.apply(line_info, color_columns, dump_mode)
147
202
 
148
203
  # 对API的输出与输入比较,进行误差判断
149
- for n, api_out in enumerate(result[n_num_input:len(result)]):
150
- num = last_len + n_num_input + n
151
- if num in red_lines:
204
+ for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
205
+ index = api_batch_start + api_batch_params_slice_index_local + n
206
+ # 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
207
+ if index in red_lines:
152
208
  continue
153
- if not isinstance(api_out[npu_max_index], (float, int)) \
154
- or not isinstance(api_out[bench_max_index], (float, int)) \
155
- or not isinstance(api_out[max_diff_index], (float, int)):
209
+ if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
156
210
  continue
157
- for _, api_in in enumerate(result[0:n_num_input]):
158
- if not isinstance(api_in[npu_max_index], (float, int)) \
159
- or not isinstance(api_in[bench_max_index], (float, int)) \
160
- or not isinstance(api_in[max_diff_index], (float, int)):
211
+
212
+ # input/parameters的比较检查, 这里api_in包括input、parameters
213
+ for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
214
+ if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
161
215
  continue
216
+ api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
217
+ apply_comparison_rules(api_info, dump_mode, color_columns)
218
+
219
+ red_lines_num_set = {x[0] for x in red_lines}
220
+ yellow_lines_num_set = {x[0] for x in yellow_lines}
221
+ highlight_dict.get('red_rows', set()).update(red_lines_num_set)
222
+ highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
223
+ highlight_dict.get('red_lines', []).extend(red_lines)
224
+ highlight_dict.get('yellow_lines', []).extend(yellow_lines)
225
+
226
+
227
+ class ApiBatch:
228
+ def __init__(self, api_name: str, start: int):
229
+ self.api_name = api_name
230
+ self.start = start
231
+ self.input_len = 1 # input的数量
232
+ self.params_end_index = start + 1 # params的结束index
233
+ self.output_end_index = start + 1 # output的结束index
234
+ self.params_grad_end_index = start + 1 # params_grad的结束index
235
+ # 内部state的标志("input", "output", "parameters", "parameters_grad"),
236
+ # 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index
237
+ self._state = Const.INPUT # api_batch初始化为input
238
+
239
+ def set_state(self, state: str):
240
+ """设置当前状态"""
241
+ if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
242
+ self._state = state
243
+ else:
244
+ raise ValueError(f"Invalid state: {state}")
245
+
246
+ def increment(self, state: str):
247
+ self.set_state(state)
248
+ if self._state == Const.INPUT or self._state == Const.KWARGS:
249
+ self.input_len += 1
250
+ self.params_end_index += 1
251
+ self.output_end_index += 1
252
+ if self._state == Const.PARAMS:
253
+ self.params_end_index += 1
254
+ self.output_end_index += 1
255
+ if self._state == Const.OUTPUT:
256
+ self.output_end_index += 1
257
+ self.params_grad_end_index += 1
258
+
259
+
260
+ def api_batches_update(api_batches, api_name, state, index):
261
+ """
262
+ 当一个api的所有item更新完后,input, output的索引范围:
263
+ input: [start: start+input_len]
264
+ output: [start+input_len: output_end_index]
265
+ params: [output_end_index: params_end_index]
266
+ """
267
+ if not api_batches:
268
+ api_batches.append(ApiBatch(api_name, index))
269
+ else:
270
+ api_batch = api_batches[-1]
271
+ if api_batch.api_name == api_name or (
272
+ not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
273
+ try:
274
+ api_batch.increment(state)
275
+ except ValueError as e:
276
+ logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
277
+ raise CompareException(CompareException.INVALID_STATE_ERROR) from e
278
+ else:
279
+ api_batches.append(ApiBatch(api_name, index))
162
280
 
163
- api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
164
- if summary_compare:
165
- for rule in HighlightRules.summary_compare_rules.values():
166
- rule.apply(api_info, color_columns, summary_compare)
167
- else:
168
- for rule in HighlightRules.compare_rules.values():
169
- rule.apply(api_info, color_columns, summary_compare)
170
281
 
171
- highlight_dict.get('red_rows', []).extend(list(set(red_lines)))
172
- highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
282
+ def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
283
+ """将dataframe根据API分组,并找到有误差的算子用于高亮"""
284
+ result = result_df.values
285
+ api_batches = []
286
+ for i, res_i in enumerate(result):
287
+ api_full_name = safe_get_value(res_i, 0, "res_i")
288
+ api_name, state = get_name_and_state(api_full_name)
289
+ api_batches_update(api_batches, api_name, state, i)
290
+ with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
291
+ for api_batch in api_batches:
292
+ find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
293
+ dump_mode)
294
+ progress_bar.update(1)
295
+
296
+
297
+ def value_check(value, api_name=None, i=None, result_df_columns=None):
298
+ if not table_value_is_valid(value):
299
+ if result_df_columns:
300
+ logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
301
+ f"is not allowed to be written into the compare result xlsx.")
302
+ else:
303
+ logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
304
+
305
+
306
+ def df_malicious_value_check(df_chunk, result_df_columns):
307
+ for row in df_chunk.itertuples(index=False):
308
+ api_name = row[0]
309
+ for i, value in enumerate(row):
310
+ value_check(value, api_name, i, result_df_columns)
173
311
 
174
312
 
175
- def get_name_and_state(name):
176
- """Get api/module name and state"""
177
- if "input" in name:
178
- api_name = name.split("input")[0]
179
- state = "input"
313
+ def handle_multi_process_malicious_value_check(func, result_df):
314
+ result_total_nums = len(result_df)
315
+ process_num = int((multiprocessing.cpu_count() + 1) / 2)
316
+
317
+ if result_total_nums <= process_num:
318
+ process_num = 1
319
+ chunks = [result_df]
180
320
  else:
181
- api_name = name.split("output")[0]
182
- state = "output"
183
- return api_name, state
321
+ chunk_size = result_total_nums // process_num
322
+ chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
184
323
 
324
+ pool = multiprocessing.Pool(process_num)
325
+
326
+ def err_call(args):
327
+ logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
328
+ try:
329
+ pool.terminate()
330
+ except OSError:
331
+ logger.error("Pool terminate failed")
332
+
333
+ result_df_columns = result_df.columns.tolist()
334
+ for column in result_df_columns:
335
+ value_check(column)
336
+ for df_chunk in chunks:
337
+ pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
338
+
339
+ pool.close()
340
+ pool.join()
185
341
 
186
- def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
187
- """将dataframe根据API分组,并找到有误差的算子用于高亮"""
188
- result = result_df.values
189
- start, input_num, output_num, end = 0, 0, 0, len(result_df)
190
- last_api_name, last_state = None, None
191
- num, last_len = 0, 0
192
- for res_i in result:
193
- api_name, state = get_name_and_state(res_i[0])
194
- if last_api_name:
195
- if api_name == last_api_name:
196
- if state == last_state:
197
- num += 1
198
- else:
199
- input_num = num
200
- num, last_state = 1, state
201
- else:
202
- output_num = num
203
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
204
- summary_compare, md5_compare)
205
- num, last_api_name, last_state = 1, api_name, state
206
- start += input_num + output_num
207
- input_num, output_num = 1, 0
208
- else:
209
- num, last_api_name, last_state = 1, api_name, state
210
- if state:
211
- if state == "input":
212
- input_num = num
213
- else:
214
- output_num = num
215
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
216
- summary_compare, md5_compare)
342
+
343
+ def compare_result_df_convert(value):
344
+ if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
345
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
346
+ if isinstance(value, float):
347
+ value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
348
+ return value
217
349
 
218
350
 
219
351
  def highlight_rows_xlsx(result_df, highlight_dict, file_path):
220
352
  """Write and highlight results in Excel"""
221
- logger.info('Compare result is %s' % file_path)
353
+
354
+ update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
222
355
 
223
356
  wb = openpyxl.Workbook()
224
357
  ws = wb.active
225
358
 
226
359
  # write header
227
- for j, col_name in enumerate(result_df.columns, start=1):
228
- if not csv_value_is_valid(col_name):
229
- raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
230
- ws.cell(row=1, column=j, value=col_name)
231
-
232
- for i, row in enumerate(result_df.iterrows(), start=2):
233
- for j, value in enumerate(row[1], start=1):
234
- if not isinstance(value, (float, int)):
235
- value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
236
- if not csv_value_is_valid(value):
237
- raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: {file_path}.")
238
- ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
239
-
240
- if (i - 2) in highlight_dict['red_rows']:
241
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
242
- end_color=CompareConst.RED, fill_type="solid")
243
- elif (i - 2) in highlight_dict['yellow_rows']:
244
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
245
- end_color=CompareConst.YELLOW, fill_type="solid")
246
-
360
+ logger.info('Initializing Excel file.')
361
+
362
+ handle_multi_process_malicious_value_check(df_malicious_value_check, result_df)
363
+
364
+ result_df_convert = result_df.applymap(compare_result_df_convert)
365
+
366
+ for row in dataframe_to_rows(result_df_convert, index=False, header=True):
367
+ ws.append(row)
368
+
369
+ # 对可疑数据标色
370
+ logger.info('Coloring Excel in progress.')
371
+ col_len = len(result_df.columns)
372
+ red_fill = PatternFill(
373
+ start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
374
+ )
375
+ yellow_fill = PatternFill(
376
+ start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
377
+ )
378
+ for i in highlight_dict.get("red_rows", []):
379
+ for j in range(1, col_len + 1):
380
+ ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
381
+ for i in highlight_dict.get("yellow_rows", []):
382
+ for j in range(1, col_len + 1):
383
+ ws.cell(row=i + 2, column=j).fill = yellow_fill
384
+
385
+ logger.info('Saving Excel file to disk: %s' % file_path)
247
386
  save_workbook(wb, file_path)
248
387
 
249
388
 
250
- def csv_value_is_valid(value: str) -> bool:
251
- if not isinstance(value, str):
252
- return True
253
- try:
254
- # -1.00 or +1.00 should be consdiered as digit numbers
255
- float(value)
256
- except ValueError:
257
- # otherwise, they will be considered as formular injections
258
- return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
259
- return True
389
+ def update_highlight_err_msg(result_df, highlight_dict):
390
+ if result_df.shape[1] <= 1:
391
+ return
392
+
393
+ if CompareConst.NPU_MD5 in result_df.columns:
394
+ return
395
+
396
+ err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
397
+ red_lines_num_set = highlight_dict.get('red_rows')
398
+
399
+ for color in ['red', 'yellow']:
400
+ line_key = f'{color}_lines'
401
+ lines = highlight_dict.get(line_key, [])
402
+ for line_index, messages in lines:
403
+ if color == 'yellow' and line_index in red_lines_num_set:
404
+ continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
405
+
406
+ for msg in messages:
407
+ if err_msg[line_index] == '':
408
+ err_msg[line_index] = msg
409
+ else:
410
+ err_msg[line_index] += '\n' + msg
411
+
412
+ if color == 'red':
413
+ red_lines_num_set.add(line_index)
414
+
415
+ result_df[CompareConst.ERROR_MESSAGE] = err_msg
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.compare.layer_mapping.layer_mapping import (
17
+ generate_data_mapping_by_layer_mapping,
18
+ generate_api_mapping_by_layer_mapping,
19
+ )