mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,16 +1,45 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import abc
17
+
2
18
  import numpy as np
3
19
  from msprobe.core.common.utils import format_value
4
20
  from msprobe.core.common.const import Const, CompareConst
5
21
  from msprobe.core.common.log import logger
6
22
 
23
+ from msprobe.core.common.utils import CompareException
24
+
7
25
 
8
26
  def handle_inf_nan(n_value, b_value):
27
+ def convert_to_float(value):
28
+ try:
29
+ if isinstance(value, np.ndarray):
30
+ return value.astype(float)
31
+ else:
32
+ return float(value)
33
+ except ValueError as e:
34
+ logger.error('\n'.join(e.args))
35
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
36
+
37
+ n_value_convert, b_value_convert = convert_to_float(n_value), convert_to_float(b_value)
9
38
  """处理inf和nan的数据"""
10
- n_inf = np.isinf(n_value)
11
- b_inf = np.isinf(b_value)
12
- n_nan = np.isnan(n_value)
13
- b_nan = np.isnan(b_value)
39
+ n_inf = np.isinf(n_value_convert)
40
+ b_inf = np.isinf(b_value_convert)
41
+ n_nan = np.isnan(n_value_convert)
42
+ b_nan = np.isnan(b_value_convert)
14
43
  n_invalid = np.any(n_inf) or np.any(n_nan)
15
44
  b_invalid = np.any(b_inf) or np.any(b_nan)
16
45
  if n_invalid or b_invalid:
@@ -35,7 +64,11 @@ def get_error_type(n_value, b_value, error_flag):
35
64
  if not n_value.shape: # 判断数据是否为标量
36
65
  return n_value, b_value, False
37
66
 
38
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
67
+ try:
68
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
69
+ except CompareException:
70
+ logger.error('Numpy data is unreadable, please check!')
71
+ return CompareConst.UNREADABLE, CompareConst.UNREADABLE, True
39
72
  if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
40
73
  return CompareConst.NAN, CompareConst.NAN, True
41
74
  return n_value, b_value, False
@@ -58,7 +91,9 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
58
91
  """获取异常情况的错误信息"""
59
92
  if error_flag:
60
93
  if n_value == CompareConst.READ_NONE:
61
- if error_file:
94
+ if error_file == 'no_bench_data':
95
+ return 'Bench does not have data file.'
96
+ elif error_file is not None:
62
97
  return "Dump file: {} not found.".format(error_file)
63
98
  return CompareConst.NO_BENCH
64
99
  if n_value == CompareConst.NONE:
@@ -67,6 +102,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
67
102
  return "Shape of NPU and bench Tensor do not match. Skipped."
68
103
  if n_value == CompareConst.NAN:
69
104
  return "The position of inf or nan in NPU and bench Tensor do not match."
105
+ if n_value == CompareConst.UNREADABLE:
106
+ return "The npy data is unable to be read or compared, please check dump data files."
70
107
  else:
71
108
  if not n_value.shape:
72
109
  return "This is type of scalar data, can not compare."
@@ -78,10 +115,8 @@ def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None
78
115
 
79
116
  def npy_data_check(n_value, b_value):
80
117
  error_message = ""
81
- if n_value is None or b_value is None:
82
- error_message += "Dump file not found.\n"
83
- if n_value == "" or b_value == "":
84
- error_message += "Dump file not found.\n"
118
+ if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
119
+ error_message += "Dump file is not ndarray.\n"
85
120
 
86
121
  # 检查 n_value 和 b_value 是否为空
87
122
  if not error_message and (n_value.size == 0 or b_value.size == 0):
@@ -96,8 +131,13 @@ def npy_data_check(n_value, b_value):
96
131
  error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
97
132
 
98
133
  if not error_message:
99
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有 nan/inf 数据
100
- if CompareConst.NAN in (n_value, b_value):
134
+ try:
135
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
136
+ except CompareException:
137
+ logger.error('Numpy data is unreadable, please check!')
138
+ return True, 'Numpy data is unreadable, please check!'
139
+ # handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
140
+ if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
101
141
  error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
102
142
  if error_message == "":
103
143
  error_flag = False
@@ -146,14 +186,14 @@ class GetCosineSimilarity(TensorComparisonBasic):
146
186
 
147
187
  def apply(self, n_value, b_value, error_flag, relative_err=None):
148
188
  if error_flag:
149
- if n_value == CompareConst.READ_NONE:
150
- return CompareConst.NONE, ''
189
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
190
+ return CompareConst.UNSUPPORTED, ''
151
191
  if n_value == CompareConst.NONE:
152
192
  return CompareConst.UNSUPPORTED, ''
153
193
  if n_value == CompareConst.SHAPE_UNMATCH:
154
194
  return CompareConst.SHAPE_UNMATCH, ''
155
195
  if n_value == CompareConst.NAN:
156
- return "N/A", ''
196
+ return CompareConst.N_A, ''
157
197
 
158
198
  if not n_value.shape:
159
199
  return CompareConst.UNSUPPORTED, ''
@@ -184,17 +224,20 @@ class GetMaxAbsErr(TensorComparisonBasic):
184
224
  """计算最大绝对误差"""
185
225
  def apply(self, n_value, b_value, error_flag, relative_err=None):
186
226
  if error_flag:
187
- if n_value == CompareConst.READ_NONE:
188
- return CompareConst.NONE, ""
227
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
228
+ return CompareConst.UNSUPPORTED, ""
189
229
  if n_value == CompareConst.NONE:
190
230
  return 0, ""
191
231
  if n_value == CompareConst.SHAPE_UNMATCH:
192
232
  return CompareConst.SHAPE_UNMATCH, ""
193
233
  if n_value == CompareConst.NAN:
194
- return "N/A", ""
234
+ return CompareConst.N_A, ""
195
235
 
196
236
  temp_res = n_value - b_value
197
237
  max_value = np.max(np.abs(temp_res))
238
+ if np.isnan(max_value):
239
+ message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
240
+ return CompareConst.NAN, message
198
241
  return format_value(max_value), ""
199
242
 
200
243
 
@@ -214,20 +257,20 @@ class GetMaxRelativeErr(TensorComparisonBasic):
214
257
  """计算最大相对误差"""
215
258
  def apply(self, n_value, b_value, error_flag, relative_err=None):
216
259
  if error_flag:
217
- if n_value == CompareConst.READ_NONE:
218
- return CompareConst.NONE, ''
260
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
261
+ return CompareConst.UNSUPPORTED, ''
219
262
  if n_value == CompareConst.NONE:
220
263
  return 0, ''
221
264
  if n_value == CompareConst.SHAPE_UNMATCH:
222
265
  return CompareConst.SHAPE_UNMATCH, ''
223
266
  if n_value == CompareConst.NAN:
224
- return "N/A", ''
267
+ return CompareConst.N_A, ''
225
268
 
226
269
  if relative_err is None:
227
270
  relative_err = get_relative_err(n_value, b_value)
228
271
  max_relative_err = np.max(np.abs(relative_err))
229
272
  if np.isnan(max_relative_err):
230
- message = 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
273
+ message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
231
274
  return CompareConst.NAN, message
232
275
  return format_value(max_relative_err), ''
233
276
 
@@ -236,14 +279,14 @@ class GetThousandErrRatio(TensorComparisonBasic):
236
279
  """计算相对误差小于千分之一的比例"""
237
280
  def apply(self, n_value, b_value, error_flag, relative_err=None):
238
281
  if error_flag:
239
- if n_value == CompareConst.READ_NONE:
240
- return CompareConst.NONE, ""
282
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
283
+ return CompareConst.UNSUPPORTED, ""
241
284
  if n_value == CompareConst.NONE:
242
285
  return 0, ""
243
286
  if n_value == CompareConst.SHAPE_UNMATCH:
244
287
  return CompareConst.SHAPE_UNMATCH, ""
245
288
  if n_value == CompareConst.NAN:
246
- return "N/A", ""
289
+ return CompareConst.N_A, ""
247
290
 
248
291
  if not n_value.shape:
249
292
  return CompareConst.NAN, ""
@@ -258,14 +301,14 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
258
301
  """计算相对误差小于千分之五的比例"""
259
302
  def apply(self, n_value, b_value, error_flag, relative_err=None):
260
303
  if error_flag:
261
- if n_value == CompareConst.READ_NONE:
262
- return CompareConst.NONE, ""
304
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
305
+ return CompareConst.UNSUPPORTED, ""
263
306
  if n_value == CompareConst.NONE:
264
307
  return 0, ""
265
308
  if n_value == CompareConst.SHAPE_UNMATCH:
266
309
  return CompareConst.SHAPE_UNMATCH, ""
267
310
  if n_value == CompareConst.NAN:
268
- return "N/A", ""
311
+ return CompareConst.N_A, ""
269
312
 
270
313
  if not n_value.shape:
271
314
  return CompareConst.NAN, ""
@@ -273,7 +316,8 @@ class GetFiveThousandErrRatio(TensorComparisonBasic):
273
316
  relative_err = get_relative_err(n_value, b_value)
274
317
  if not np.size(relative_err):
275
318
  return CompareConst.NAN, ""
276
- return format_value(np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
319
+ return format_value(
320
+ np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
277
321
 
278
322
 
279
323
  class CompareOps: