mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,5 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.log import logger
2
- from msprobe.core.compare.utils import rename_api
17
+ from msprobe.core.compare.utils import rename_api
18
+ from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
19
+ from msprobe.core.common.const import Const
3
20
 
4
21
 
5
22
  dtype_mapping = {
@@ -18,24 +35,28 @@ dtype_mapping = {
18
35
  "BFloat16": "torch.bfloat16",
19
36
  "Complex64": "torch.complex64",
20
37
  "Complex128": "torch.complex128"
21
- }
38
+ }
22
39
 
23
40
 
24
- def check_struct_match(npu_dict, bench_dict, cross_frame=False):
41
+ def check_struct_match(npu_dict, bench_dict):
25
42
  npu_struct_in = npu_dict.get("input_struct")
26
43
  bench_struct_in = bench_dict.get("input_struct")
27
44
  npu_struct_out = npu_dict.get("output_struct")
28
45
  bench_struct_out = bench_dict.get("output_struct")
29
46
 
30
- if cross_frame:
31
- npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
32
- npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
33
47
  is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
34
48
  if not is_match:
35
49
  if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
36
50
  return False
37
- struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
38
- struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
51
+ try:
52
+ struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
53
+ struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
54
+ except CompareException as error:
55
+ err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
56
+ f'npu_dict: {npu_dict}' \
57
+ f'bench_dict: {bench_dict}'
58
+ logger.error(err_msg)
59
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
39
60
  is_match = struct_in_is_match and struct_out_is_match
40
61
  return is_match
41
62
 
@@ -43,17 +64,27 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
43
64
  def check_type_shape_match(npu_struct, bench_struct):
44
65
  shape_type_match = False
45
66
  for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
46
- npu_type = npu_type_shape[0]
47
- npu_shape = npu_type_shape[1]
48
- bench_type = bench_type_shape[0]
49
- bench_shape = bench_type_shape[1]
67
+ try:
68
+ npu_type = npu_type_shape[0]
69
+ npu_shape = npu_type_shape[1]
70
+ bench_type = bench_type_shape[0]
71
+ bench_shape = bench_type_shape[1]
72
+ except IndexError as error:
73
+ logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
74
+ f'should both be 2, please check!')
75
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
50
76
  shape_match = npu_shape == bench_shape
51
77
  type_match = npu_type == bench_type
52
78
  if not type_match:
53
- ms_type=[["Float16", "Float32"], ["Float32", "Float16"],["Float16", "BFloat16"],["BFloat16", "Float16"]]
54
- torch_type=[["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"],
55
- ["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]]
56
- if ([npu_type, bench_type] in ms_type)or ([npu_type, bench_type] in torch_type):
79
+ ms_type = [
80
+ [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
81
+ [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
82
+ ]
83
+ torch_type = [
84
+ [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
85
+ [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
86
+ ]
87
+ if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
57
88
  type_match = True
58
89
  else:
59
90
  type_match = False
@@ -64,9 +95,9 @@ def check_type_shape_match(npu_struct, bench_struct):
64
95
 
65
96
 
66
97
  def check_graph_mode(a_op_name, b_op_name):
67
- if "Aten" in a_op_name and "Aten" not in b_op_name:
98
+ if Const.ATEN in a_op_name and Const.ATEN not in b_op_name:
68
99
  return True
69
- if "Aten" not in a_op_name and "Aten" in b_op_name:
100
+ if Const.ATEN not in a_op_name and Const.ATEN in b_op_name:
70
101
  return True
71
102
  return False
72
103
 
@@ -83,13 +114,64 @@ def fuzzy_check_op(npu_name_list, bench_name_list):
83
114
 
84
115
 
85
116
  def fuzzy_check_name(npu_name, bench_name):
86
- if "forward" in npu_name and "forward" in bench_name:
87
- is_match = rename_api(npu_name, "forward") == rename_api(bench_name, "forward")
88
- elif "backward" in npu_name and "backward" in bench_name:
89
- is_match = rename_api(npu_name, "backward") == rename_api(bench_name, "backward")
117
+ if Const.FORWARD in npu_name and Const.FORWARD in bench_name:
118
+ is_match = rename_api(npu_name, Const.FORWARD) == rename_api(bench_name, Const.FORWARD)
119
+ elif Const.BACKWARD in npu_name and Const.BACKWARD in bench_name:
120
+ is_match = rename_api(npu_name, Const.BACKWARD) == rename_api(bench_name, Const.BACKWARD)
90
121
  else:
91
122
  is_match = npu_name == bench_name
92
123
  return is_match
93
124
 
94
125
 
126
+ def check_dump_json_str(op_data, op_name):
127
+ input_list = op_data.get(Const.INPUT_ARGS, None) if op_data.get(Const.INPUT_ARGS, None) else op_data.get(
128
+ Const.INPUT, None)
129
+ input_kwargs = op_data.get(Const.INPUT_KWARGS, None)
130
+ output_list = op_data.get(Const.OUTPUT, None)
131
+
132
+ args = [input_list, input_kwargs, output_list]
133
+ for arg in args:
134
+ if not arg:
135
+ continue
136
+ if isinstance(arg, dict):
137
+ check_json_key_value(arg, op_name)
138
+ else:
139
+ for ele in arg:
140
+ if not ele:
141
+ continue
142
+ check_json_key_value(ele, op_name)
143
+
144
+
145
+ def check_json_key_value(input_output, op_name, depth=0):
146
+ if depth > Const.MAX_DEPTH:
147
+ logger.error(f"string check of data info of {op_name} exceeds the recursion limit.")
148
+ return
149
+ if isinstance(input_output, list):
150
+ for item in input_output:
151
+ check_json_key_value(item, op_name, depth+1)
152
+ elif isinstance(input_output, dict):
153
+ for key, value in input_output.items():
154
+ if isinstance(value, dict):
155
+ check_json_key_value(value, op_name, depth+1)
156
+ else:
157
+ valid_key_value(key, value, op_name)
158
+
95
159
 
160
+ def valid_key_value(key, value, op_name):
161
+ if key == "shape" and not isinstance(value, (list, tuple)):
162
+ logger.error(f"shape of input or output of {op_name} is not list or tuple, please check!")
163
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
164
+ elif key == "requires_grad" and not isinstance(value, bool):
165
+ logger.error(f"requires_grad of input or output of {op_name} is not bool, please check!")
166
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
167
+ else:
168
+ check_op_str_pattern_valid(value, op_name)
169
+
170
+
171
+ def check_stack_json_str(stack_info, op_name):
172
+ if isinstance(stack_info, list):
173
+ for item in stack_info:
174
+ check_op_str_pattern_valid(item, op_name, stack=True)
175
+ else:
176
+ logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
177
+ raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
@@ -1,15 +1,35 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import json
2
- from msprobe.core.common.file_utils import FileOpen, check_file_type
17
+ from msprobe.core.common.file_utils import check_file_type, load_json
3
18
  from msprobe.core.common.const import FileCheckConst, Const
4
19
  from msprobe.core.common.utils import CompareException
5
20
  from msprobe.core.common.log import logger
6
21
 
7
22
 
8
23
  def compare_cli(args):
9
- with FileOpen(args.input_path, "r") as file:
10
- input_param = json.load(file)
24
+ input_param = load_json(args.input_path)
11
25
  npu_path = input_param.get("npu_path", None)
12
26
  bench_path = input_param.get("bench_path", None)
27
+ if not npu_path:
28
+ logger.error(f"Missing npu_path in configuration file {args.input_path}, please check!")
29
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
30
+ if not bench_path:
31
+ logger.error(f"Missing bench_path in configuration file {args.input_path}, please check!")
32
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
13
33
  frame_name = args.framework
14
34
  auto_analyze = not args.compare_only
15
35
  if frame_name == Const.PT_FRAMEWORK:
@@ -19,12 +39,18 @@ def compare_cli(args):
19
39
  from msprobe.mindspore.compare.ms_compare import ms_compare
20
40
  from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
21
41
  if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
42
+ if "stack_path" not in input_param:
43
+ logger.error(f"Missing stack_path in configuration file {args.input_path}, please check!")
44
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
22
45
  input_param["npu_json_path"] = input_param.pop("npu_path")
23
46
  input_param["bench_json_path"] = input_param.pop("bench_path")
24
47
  input_param["stack_json_path"] = input_param.pop("stack_path")
25
48
  if frame_name == Const.PT_FRAMEWORK:
49
+ kwargs = {
50
+ "data_mapping": args.data_mapping
51
+ }
26
52
  compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
27
- fuzzy_match=args.fuzzy_match)
53
+ fuzzy_match=args.fuzzy_match, **kwargs)
28
54
  else:
29
55
  kwargs = {
30
56
  "stack_mode": args.stack_mode,
@@ -32,11 +58,22 @@ def compare_cli(args):
32
58
  "fuzzy_match": args.fuzzy_match,
33
59
  "cell_mapping": args.cell_mapping,
34
60
  "api_mapping": args.api_mapping,
61
+ "data_mapping": args.data_mapping,
62
+ "layer_mapping": args.layer_mapping
35
63
  }
36
64
 
37
65
  ms_compare(input_param, args.output_path, **kwargs)
38
66
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
39
- kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
67
+ kwargs = {
68
+ "stack_mode": args.stack_mode,
69
+ "auto_analyze": auto_analyze,
70
+ "fuzzy_match": args.fuzzy_match,
71
+ "is_print_compare_log": input_param.get("is_print_compare_log", True),
72
+ "cell_mapping": args.cell_mapping,
73
+ "api_mapping": args.api_mapping,
74
+ "data_mapping": args.data_mapping,
75
+ "layer_mapping": args.layer_mapping
76
+ }
40
77
  if input_param.get("rank_id") is not None:
41
78
  ms_graph_compare(input_param, args.output_path)
42
79
  return
@@ -1,89 +1,127 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  import abc
18
+ import re
3
19
  from collections import namedtuple
4
20
  import numpy as np
5
21
  import openpyxl
6
22
  from openpyxl.styles import PatternFill
23
+ from tqdm import tqdm
7
24
  from msprobe.core.common.utils import get_header_index
8
25
  from msprobe.core.common.file_utils import save_workbook
9
26
  from msprobe.core.common.log import logger
10
- from msprobe.core.common.const import CompareConst
27
+ from msprobe.core.common.const import CompareConst, FileCheckConst, Const
28
+ from msprobe.core.common.utils import safe_get_value
11
29
 
12
30
 
13
31
  class HighlightCheck(abc.ABC):
14
32
  @abc.abstractmethod
15
- def apply(self, info, color_columns, summary_compare):
33
+ def apply(self, info, color_columns, dump_mode):
16
34
  raise NotImplementedError
17
35
 
18
36
 
37
def add_highlight_row_info(color_list, num, highlight_err_msg):
    """Record ``highlight_err_msg`` for row ``num`` inside ``color_list``.

    ``color_list`` holds ``(row_number, [messages])`` tuples and is modified
    in place. When ``num`` already has an entry the message is appended to
    that entry's list; otherwise a new entry is added at the end.

    Args:
        color_list: List of ``(row_number, [messages])`` tuples.
        num: Row number the message belongs to.
        highlight_err_msg: Text explaining why the row is highlighted.
    """
    for row_number, row_messages in color_list:
        if num == row_number:
            # The tuple itself is immutable, but the list it carries is not.
            row_messages.append(highlight_err_msg)
            return
    color_list.append((num, [highlight_err_msg]))
43
+
44
+
19
45
  class CheckOrderMagnitude(HighlightCheck):
20
46
  """检查Max diff的数量级差异"""
21
- def apply(self, info, color_columns, summary_compare=True):
47
+ def apply(self, info, color_columns, dump_mode):
22
48
  api_in, api_out, num = info
23
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
49
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
50
+ else CompareConst.MAX_ABS_ERR, dump_mode)
24
51
  if abs(api_in[max_diff_index]) > abs(api_out[max_diff_index]):
25
52
  return
26
53
  in_order = 0 if abs(api_in[max_diff_index]) < 1 else math.log10(abs(api_in[max_diff_index]))
27
54
  out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
28
55
  if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
29
- color_columns.yellow.append(num)
56
+ add_highlight_row_info(color_columns.yellow, num,
57
+ "maximum absolute error of both input and output exceed 1, "
58
+ "with the output larger by an order of magnitude")
30
59
 
31
60
 
32
61
  class CheckOneThousandErrorRatio(HighlightCheck):
33
62
  """检查千分误差比率"""
34
- def apply(self, info, color_columns, summary_compare=True):
63
+ def apply(self, info, color_columns, dump_mode):
35
64
  api_in, api_out, num = info
36
- one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
37
- if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)):
65
+ one_thousand_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
66
+ if (not isinstance(api_in[one_thousand_index], (float, int)) or
67
+ not isinstance(api_out[one_thousand_index], (float, int))):
38
68
  return
39
- if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
40
- color_columns.red.append(num)
69
+ if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
70
+ api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
71
+ add_highlight_row_info(color_columns.red, num,
72
+ "The input's one thousandth err ratio exceeds 0.9, while the output's is below 0.6")
41
73
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
42
- color_columns.yellow.append(num)
74
+ add_highlight_row_info(color_columns.yellow, num,
75
+ "The output's one thousandth err ratio decreases by more than 0.1 "
76
+ "compared to the input's")
43
77
 
44
78
 
45
79
  class CheckCosineSimilarity(HighlightCheck):
46
80
  """检查余弦相似度"""
47
- def apply(self, info, color_columns, summary_compare=True):
81
+ def apply(self, info, color_columns, dump_mode):
48
82
  api_in, api_out, num = info
49
- cosine_index = get_header_index('Cosine', summary_compare)
83
+ cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
50
84
  if not isinstance(api_in[cosine_index], (float, int)) or not isinstance(api_out[cosine_index], (float, int)):
51
85
  return
52
86
  if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
53
- color_columns.yellow.append(num)
87
+ add_highlight_row_info(color_columns.yellow, num,
88
+ "The output's cosine decreases by more than 0.1 compared to the input's")
54
89
 
55
90
 
56
91
  class CheckMaxRelativeDiff(HighlightCheck):
57
92
  """检查最大相对差异"""
58
- def apply(self, info, color_columns, summary_compare=True):
93
+ def apply(self, info, color_columns, dump_mode):
59
94
  api_in, api_out, num = info
60
- max_diff_index = get_header_index('Max diff', summary_compare)
61
- bench_max_index = get_header_index('Bench max', summary_compare)
95
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
96
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
62
97
  input_max_relative_diff = np.abs(np.divide(api_in[max_diff_index], max(0.01, api_in[bench_max_index])))
63
98
  output_max_relative_diff = np.abs(np.divide(api_out[max_diff_index], max(0.01, api_out[bench_max_index])))
64
99
  if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
65
100
  (float, int)):
66
101
  return
67
102
  if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
68
- color_columns.red.append(num)
69
- elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW:
70
- color_columns.yellow.append(num)
103
+ add_highlight_row_info(color_columns.red, num, "maximum relative error exceeds 0.5")
104
+ elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
105
+ input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
106
+ add_highlight_row_info(color_columns.yellow, num,
107
+ "The output's maximum relative error exceeds 0.1, while the input's is below 0.01")
71
108
 
72
109
 
73
110
  class CheckOverflow(HighlightCheck):
74
111
  """检查是否存在溢出"""
75
- def apply(self, info, color_columns, summary_compare=True):
112
+ def apply(self, info, color_columns, dump_mode):
76
113
  line, num = info
77
- npu_max_index = get_header_index('NPU max', summary_compare)
78
- npu_min_index = get_header_index('NPU min', summary_compare)
79
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
114
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
115
+ npu_min_index = get_header_index(CompareConst.NPU_MIN, dump_mode)
116
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
117
+ else CompareConst.MAX_ABS_ERR, dump_mode)
80
118
  if str(line[npu_max_index]) in CompareConst.OVERFLOW_LIST or str(
81
119
  line[npu_min_index]) in CompareConst.OVERFLOW_LIST:
82
- color_columns.red.append(num)
120
+ add_highlight_row_info(color_columns.red, num, "maximum or minimum is nan, -inf, or inf")
83
121
  return
84
122
  # check if Max_Diff > 1e+10
85
- if isinstance(line[max_diff_index], (float, int)) and line[max_diff_index] > CompareConst.MAX_DIFF_RED:
86
- color_columns.red.append(num)
123
+ if isinstance(line[max_diff_index], (float, int)) and abs(line[max_diff_index]) > CompareConst.MAX_DIFF_RED:
124
+ add_highlight_row_info(color_columns.red, num, "maximum absolute error exceeds 1e+10")
87
125
 
88
126
 
89
127
  class HighlightRules:
@@ -105,13 +143,14 @@ class HighlightRules:
105
143
  }
106
144
 
107
145
 
108
- def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False):
146
+ def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
109
147
  """找到单个API中需要高亮的行"""
110
- if md5_compare:
148
+ if dump_mode == Const.MD5:
111
149
  return
112
- npu_max_index = get_header_index('NPU max', summary_compare)
113
- bench_max_index = get_header_index('Bench max', summary_compare)
114
- max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare)
150
+ npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
151
+ bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
152
+ max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
153
+ else CompareConst.MAX_ABS_ERR, dump_mode)
115
154
 
116
155
  red_lines, yellow_lines = [], []
117
156
  LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
@@ -124,7 +163,7 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
124
163
  num = last_len + i
125
164
  line_info = LineInfo(line_data=line, num_pointer=num)
126
165
  for rule in HighlightRules.basic_rules.values():
127
- rule.apply(line_info, color_columns, summary_compare)
166
+ rule.apply(line_info, color_columns, dump_mode)
128
167
 
129
168
  # 对API的输出与输入比较,进行误差判断
130
169
  for n, api_out in enumerate(result[n_num_input:len(result)]):
@@ -142,36 +181,42 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa
142
181
  continue
143
182
 
144
183
  api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num)
145
- if summary_compare:
184
+ if dump_mode == Const.SUMMARY:
146
185
  for rule in HighlightRules.summary_compare_rules.values():
147
- rule.apply(api_info, color_columns, summary_compare)
186
+ rule.apply(api_info, color_columns, dump_mode)
148
187
  else:
149
188
  for rule in HighlightRules.compare_rules.values():
150
- rule.apply(api_info, color_columns, summary_compare)
189
+ rule.apply(api_info, color_columns, dump_mode)
151
190
 
152
- highlight_dict.get('red_rows', []).extend(list(set(red_lines)))
153
- highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines)))
191
+ red_lines_num_set = {x[0] for x in red_lines}
192
+ yellow_lines_num_set = {x[0] for x in yellow_lines}
193
+ highlight_dict.get('red_rows', set()).update(red_lines_num_set)
194
+ highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
195
+ highlight_dict.get('red_lines', []).extend(red_lines)
196
+ highlight_dict.get('yellow_lines', []).extend(yellow_lines)
154
197
 
155
198
 
156
199
  def get_name_and_state(name):
157
200
  """Get api/module name and state"""
158
- if "input" in name:
159
- api_name = name.split("input")[0]
160
- state = "input"
201
+ if Const.INPUT in name:
202
+ api_name = name.split(Const.INPUT)[0]
203
+ state = Const.INPUT
161
204
  else:
162
- api_name = name.split("output")[0]
163
- state = "output"
205
+ api_name = name.split(Const.OUTPUT)[0]
206
+ state = Const.OUTPUT
164
207
  return api_name, state
165
208
 
166
209
 
167
- def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare):
210
+ def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
168
211
  """将dataframe根据API分组,并找到有误差的算子用于高亮"""
169
212
  result = result_df.values
170
213
  start, input_num, output_num, end = 0, 0, 0, len(result_df)
171
214
  last_api_name, last_state = None, None
172
215
  num, last_len = 0, 0
216
+ progress_bar = tqdm(total=len(result), desc="API/Module Analyse Progress", unit="item", ncols=100)
173
217
  for res_i in result:
174
- api_name, state = get_name_and_state(res_i[0])
218
+ api_full_name = safe_get_value(res_i, 0, "res_i")
219
+ api_name, state = get_name_and_state(api_full_name)
175
220
  if last_api_name:
176
221
  if api_name == last_api_name:
177
222
  if state == last_state:
@@ -182,42 +227,102 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
182
227
  else:
183
228
  output_num = num
184
229
  find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
185
- summary_compare, md5_compare)
230
+ dump_mode)
186
231
  num, last_api_name, last_state = 1, api_name, state
187
232
  start += input_num + output_num
188
233
  input_num, output_num = 1, 0
189
234
  else:
190
235
  num, last_api_name, last_state = 1, api_name, state
236
+ progress_bar.update(1)
237
+ progress_bar.close()
191
238
  if state:
192
- if state == "input":
239
+ if state == Const.INPUT:
193
240
  input_num = num
194
241
  else:
195
242
  output_num = num
196
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare)
243
+ find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
244
+ dump_mode)
197
245
 
198
246
 
199
247
  def highlight_rows_xlsx(result_df, highlight_dict, file_path):
200
248
  """Write and highlight results in Excel"""
201
- logger.info('Compare result is %s' % file_path)
249
+
250
+ update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
202
251
 
203
252
  wb = openpyxl.Workbook()
204
253
  ws = wb.active
205
254
 
206
255
  # write header
256
+ logger.info('Initializing Excel file.')
207
257
  for j, col_name in enumerate(result_df.columns, start=1):
258
+ if not csv_value_is_valid(col_name):
259
+ raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
208
260
  ws.cell(row=1, column=j, value=col_name)
209
261
 
210
262
  for i, row in enumerate(result_df.iterrows(), start=2):
211
263
  for j, value in enumerate(row[1], start=1):
212
- if not isinstance(value, (float, int)):
264
+ if not isinstance(value, (float, int)) or isinstance(value, bool):
213
265
  value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
266
+ if not csv_value_is_valid(value):
267
+ raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: "
268
+ f"{file_path}.")
214
269
  ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
270
+
271
+ # 对可疑数据标色
272
+ logger.info('Coloring Excel in progress.')
273
+ col_len = len(result_df.columns)
274
+ red_fill = PatternFill(
275
+ start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
276
+ )
277
+ yellow_fill = PatternFill(
278
+ start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
279
+ )
280
+ for i in highlight_dict.get("red_rows", []):
281
+ for j in range(1, col_len + 1):
282
+ ws.cell(row=i + 2, column=j).fill = red_fill
283
+ for i in highlight_dict.get("yellow_rows", []):
284
+ for j in range(1, col_len + 1):
285
+ ws.cell(row=i + 2, column=j).fill = yellow_fill
286
+ logger.info('Saving Excel file to disk: %s' % file_path)
287
+ save_workbook(wb, file_path)
215
288
 
216
- if (i - 2) in highlight_dict['red_rows']:
217
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED,
218
- end_color=CompareConst.RED, fill_type="solid")
219
- elif (i - 2) in highlight_dict['yellow_rows']:
220
- ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
221
- end_color=CompareConst.YELLOW, fill_type="solid")
222
289
 
223
- save_workbook(wb, file_path)
290
def update_highlight_err_msg(result_df, highlight_dict):
    """Append the highlight reasons to the error-message column of ``result_df``.

    Nothing happens when ``result_df`` has at most one column or contains the
    ``CompareConst.NPU_MD5`` column. Otherwise every message stored under
    ``red_lines`` and then ``yellow_lines`` in ``highlight_dict`` is added to
    the ``CompareConst.ERROR_MESSAGE`` cell of its row, separated by newlines.
    Yellow entries for rows already present in ``red_rows`` are skipped.

    Args:
        result_df: DataFrame holding the comparison result; updated in place.
        highlight_dict: Mapping with ``red_rows`` plus ``red_lines`` /
            ``yellow_lines`` lists of ``(row_index, [messages])`` pairs.
            Red row indices are also added to ``red_rows``.
    """
    if result_df.shape[1] <= 1 or CompareConst.NPU_MD5 in result_df.columns:
        return

    err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
    red_lines_num_set = highlight_dict.get('red_rows')

    # Red is processed first so that yellow entries can be checked against it.
    for color in ('red', 'yellow'):
        is_red = color == 'red'
        for line_index, messages in highlight_dict.get(f'{color}_lines', []):
            if not is_red and line_index in red_lines_num_set:
                continue  # a yellow row already covered by a red row is skipped

            for msg in messages:
                current = err_msg[line_index]
                err_msg[line_index] = msg if current == '' else current + ('\n' + msg)

            if is_red:
                red_lines_num_set.add(line_index)

    result_df[CompareConst.ERROR_MESSAGE] = err_msg
317
+
318
+
319
def csv_value_is_valid(value: str) -> bool:
    """Tell whether ``value`` may be written to a spreadsheet cell.

    Non-string values are always accepted. A string is accepted when
    ``float()`` can parse it (so signed numbers such as ``-1.00`` or ``+1.00``
    pass). Any other string is treated as a possible formula injection and is
    accepted only if ``FileCheckConst.CSV_BLACK_LIST`` does not match it.

    Returns:
        bool: ``True`` when the value is acceptable, ``False`` otherwise.
    """
    if not isinstance(value, str):
        return True
    try:
        float(value)
    except ValueError:
        # Not a number: reject it if the black-list pattern matches anywhere.
        return re.search(FileCheckConst.CSV_BLACK_LIST, value) is None
    else:
        return True
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from msprobe.core.compare.layer_mapping.layer_mapping import (
17
+ generate_data_mapping_by_layer_mapping,
18
+ generate_api_mapping_by_layer_mapping,
19
+ )