mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -14,18 +14,31 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import abc
17
+
17
18
  import numpy as np
18
- from msprobe.core.common.utils import format_value
19
+
19
20
  from msprobe.core.common.const import Const, CompareConst
20
21
  from msprobe.core.common.log import logger
22
+ from msprobe.core.common.utils import CompareException, format_value
21
23
 
22
24
 
23
25
  def handle_inf_nan(n_value, b_value):
26
+ def convert_to_float(value):
27
+ try:
28
+ if isinstance(value, np.ndarray):
29
+ return value.astype(float)
30
+ else:
31
+ return float(value)
32
+ except ValueError as e:
33
+ logger.error('\n'.join(e.args))
34
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
35
+
36
+ n_value_convert, b_value_convert = convert_to_float(n_value), convert_to_float(b_value)
24
37
  """处理inf和nan的数据"""
25
- n_inf = np.isinf(n_value)
26
- b_inf = np.isinf(b_value)
27
- n_nan = np.isnan(n_value)
28
- b_nan = np.isnan(b_value)
38
+ n_inf = np.isinf(n_value_convert)
39
+ b_inf = np.isinf(b_value_convert)
40
+ n_nan = np.isnan(n_value_convert)
41
+ b_nan = np.isnan(b_value_convert)
29
42
  n_invalid = np.any(n_inf) or np.any(n_nan)
30
43
  b_invalid = np.any(b_inf) or np.any(b_nan)
31
44
  if n_invalid or b_invalid:
@@ -39,58 +52,66 @@ def handle_inf_nan(n_value, b_value):
39
52
  return n_value, b_value
40
53
 
41
54
 
42
- def get_error_type(n_value, b_value, error_flag):
43
- """判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag"""
55
+ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
56
+ """判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag和error_msg"""
57
+ err_msg = ""
44
58
  if error_flag:
45
- return CompareConst.READ_NONE, CompareConst.READ_NONE, True
59
+ if error_file == "no_bench_data":
60
+ err_msg = "Bench does not have data file."
61
+ elif error_file:
62
+ err_msg = f"Dump file: {error_file} not found."
63
+ else:
64
+ err_msg = CompareConst.NO_BENCH
65
+ error_flag = True
66
+ return CompareConst.READ_NONE, CompareConst.READ_NONE, error_flag, err_msg
67
+
46
68
  if n_value.size == 0: # 判断读取到的数据是否为空
47
- return CompareConst.NONE, CompareConst.NONE, True
69
+ err_msg = "This is empty data, can not compare."
70
+ error_flag = True
71
+ return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
72
+ if not n_value.shape: # 判断数据是否为0维张量
73
+ err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
74
+ f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
75
+ error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
76
+ return n_value, b_value, error_flag, err_msg
48
77
  if n_value.shape != b_value.shape: # 判断NPU和bench的数据结构是否一致
49
- return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, True
50
- if not n_value.shape: # 判断数据是否为标量
51
- return n_value, b_value, False
78
+ err_msg = "Shape of NPU and bench tensor do not match. Skipped."
79
+ error_flag = True
80
+ return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, error_flag, err_msg
52
81
 
53
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
82
+ try:
83
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
84
+ except CompareException:
85
+ logger.error('Numpy data is unreadable, please check!')
86
+ err_msg = "Data is unreadable."
87
+ error_flag = True
88
+ return CompareConst.UNREADABLE, CompareConst.UNREADABLE, error_flag, err_msg
54
89
  if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
55
- return CompareConst.NAN, CompareConst.NAN, True
56
- return n_value, b_value, False
90
+ err_msg = "The position of inf or nan in NPU and bench Tensor do not match."
91
+ error_flag = True
92
+ return CompareConst.NAN, CompareConst.NAN, error_flag, err_msg
93
+
94
+ if n_value.dtype != b_value.dtype: # 判断数据的dtype是否一致
95
+ err_msg = "Dtype of NPU and bench tensor do not match."
96
+ error_flag = False
97
+ return n_value, b_value, error_flag, err_msg
98
+
99
+ return n_value, b_value, error_flag, err_msg
57
100
 
58
101
 
59
102
  def reshape_value(n_value, b_value):
60
103
  """返回reshape后的数据"""
61
- if not n_value.shape: # 判断数据是否为标量
104
+ if not n_value.shape: # 判断数据是否为0维tensor, 如果0维tensor,不会转成1维tensor,直接返回
62
105
  if n_value.dtype == bool:
63
106
  n_value = n_value.astype(float)
64
107
  b_value = b_value.astype(float)
65
108
  return n_value, b_value
66
109
 
67
- n_value = n_value.reshape(-1).astype(float)
110
+ n_value = n_value.reshape(-1).astype(float) # 32转64为了防止某些数转dataframe时出现误差
68
111
  b_value = b_value.reshape(-1).astype(float)
69
112
  return n_value, b_value
70
113
 
71
114
 
72
- def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
73
- """获取异常情况的错误信息"""
74
- if error_flag:
75
- if n_value == CompareConst.READ_NONE:
76
- if error_file:
77
- return "Dump file: {} not found.".format(error_file)
78
- return CompareConst.NO_BENCH
79
- if n_value == CompareConst.NONE:
80
- return "This is empty data, can not compare."
81
- if n_value == CompareConst.SHAPE_UNMATCH:
82
- return "Shape of NPU and bench Tensor do not match. Skipped."
83
- if n_value == CompareConst.NAN:
84
- return "The position of inf or nan in NPU and bench Tensor do not match."
85
- else:
86
- if not n_value.shape:
87
- return "This is type of scalar data, can not compare."
88
- if n_value.dtype != b_value.dtype:
89
- logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
90
- return "Dtype of NPU and bench Tensor do not match."
91
- return ""
92
-
93
-
94
115
  def npy_data_check(n_value, b_value):
95
116
  error_message = ""
96
117
  if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
@@ -109,7 +130,11 @@ def npy_data_check(n_value, b_value):
109
130
  error_message += "Dtype of NPU and bench Tensor do not match. Skipped.\n"
110
131
 
111
132
  if not error_message:
112
- n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有 nan/inf 数据
133
+ try:
134
+ n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
135
+ except CompareException:
136
+ logger.error('Numpy data is unreadable, please check!')
137
+ return True, 'Numpy data is unreadable, please check!'
113
138
  # handle_inf_nan 会返回'Nan'或ndarray类型,使用类型判断是否存在无法处理的nan/inf数据
114
139
  if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
115
140
  error_message += "The position of inf or nan in NPU and bench Tensor do not match.\n"
@@ -144,10 +169,25 @@ def statistics_data_check(result_dict):
144
169
  class TensorComparisonBasic(abc.ABC):
145
170
  """NPU和bench中npy数据的比较模板"""
146
171
  @abc.abstractmethod
147
- def apply(self, n_value, b_value, error_flag, relative_err=None):
172
+ def apply(self, n_value, b_value, relative_err):
148
173
  raise NotImplementedError
149
174
 
150
175
 
176
+ def get_relative_err(n_value, b_value):
177
+ """计算相对误差"""
178
+ with np.errstate(divide='ignore', invalid='ignore'):
179
+ if b_value.dtype not in CompareConst.FLOAT_TYPE:
180
+ n_value, b_value = n_value.astype(float), b_value.astype(float)
181
+
182
+ n_value_copy = n_value.copy()
183
+ b_value_copy = b_value.copy()
184
+ zero_mask = (b_value_copy == 0)
185
+ b_value_copy[zero_mask] += Const.FLOAT_EPSILON
186
+ n_value_copy[zero_mask] += Const.FLOAT_EPSILON
187
+ relative_err = np.divide((n_value_copy - b_value_copy), b_value_copy)
188
+ return np.abs(relative_err)
189
+
190
+
151
191
  class GetCosineSimilarity(TensorComparisonBasic):
152
192
  """计算cosine相似度"""
153
193
  @staticmethod
@@ -158,137 +198,67 @@ class GetCosineSimilarity(TensorComparisonBasic):
158
198
  return round(float(result), 6)
159
199
  return result
160
200
 
161
- def apply(self, n_value, b_value, error_flag, relative_err=None):
162
- if error_flag:
163
- if n_value == CompareConst.READ_NONE:
164
- return CompareConst.NONE, ''
165
- if n_value == CompareConst.NONE:
166
- return CompareConst.UNSUPPORTED, ''
167
- if n_value == CompareConst.SHAPE_UNMATCH:
168
- return CompareConst.SHAPE_UNMATCH, ''
169
- if n_value == CompareConst.NAN:
170
- return "N/A", ''
171
-
201
+ def apply(self, n_value, b_value, relative_err):
172
202
  if not n_value.shape:
173
- return CompareConst.UNSUPPORTED, ''
203
+ return CompareConst.UNSUPPORTED, ""
174
204
 
175
- with np.errstate(divide='ignore', invalid='ignore'):
205
+ with np.errstate(divide="ignore", invalid="ignore"):
176
206
  if len(n_value) == 1:
177
- return CompareConst.UNSUPPORTED, "This tensor is scalar."
207
+ return CompareConst.UNSUPPORTED, "This is a 1-d tensor of length 1."
178
208
  num = n_value.dot(b_value)
179
209
  a_norm = np.linalg.norm(n_value)
180
210
  b_norm = np.linalg.norm(b_value)
181
211
 
182
212
  if a_norm <= Const.FLOAT_EPSILON and b_norm <= Const.FLOAT_EPSILON:
183
- return 1.0, ''
213
+ return 1.0, ""
184
214
  if a_norm <= Const.FLOAT_EPSILON:
185
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.'
215
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in npu dump data."
186
216
  if b_norm <= Const.FLOAT_EPSILON:
187
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.'
217
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data."
188
218
 
189
219
  cos = num / (a_norm * b_norm)
190
220
  if np.isnan(cos):
191
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, the dump data has NaN.'
221
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, the dump data has NaN."
192
222
  result = format_value(cos)
193
223
  result = self.correct_data(result)
194
- return 1.0 if float(result) > 0.99999 else result, ''
224
+ return result, ""
195
225
 
196
226
 
197
227
  class GetMaxAbsErr(TensorComparisonBasic):
198
228
  """计算最大绝对误差"""
199
- def apply(self, n_value, b_value, error_flag, relative_err=None):
200
- if error_flag:
201
- if n_value == CompareConst.READ_NONE:
202
- return CompareConst.NONE, ""
203
- if n_value == CompareConst.NONE:
204
- return 0, ""
205
- if n_value == CompareConst.SHAPE_UNMATCH:
206
- return CompareConst.SHAPE_UNMATCH, ""
207
- if n_value == CompareConst.NAN:
208
- return "N/A", ""
209
-
229
+ def apply(self, n_value, b_value, relative_err):
210
230
  temp_res = n_value - b_value
211
231
  max_value = np.max(np.abs(temp_res))
232
+ if np.isnan(max_value):
233
+ msg = "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data."
234
+ return CompareConst.NAN, msg
212
235
  return format_value(max_value), ""
213
236
 
214
237
 
215
- def get_relative_err(n_value, b_value):
216
- """计算相对误差"""
217
- with np.errstate(divide='ignore', invalid='ignore'):
218
- if b_value.dtype not in CompareConst.FLOAT_TYPE:
219
- n_value, b_value = n_value.astype(float), b_value.astype(float)
220
- zero_mask = (b_value == 0)
221
- b_value[zero_mask] += np.finfo(b_value.dtype).eps
222
- n_value[zero_mask] += np.finfo(b_value.dtype).eps
223
- relative_err = np.divide((n_value - b_value), b_value)
224
- return np.abs(relative_err)
225
-
226
-
227
238
  class GetMaxRelativeErr(TensorComparisonBasic):
228
239
  """计算最大相对误差"""
229
- def apply(self, n_value, b_value, error_flag, relative_err=None):
230
- if error_flag:
231
- if n_value == CompareConst.READ_NONE:
232
- return CompareConst.NONE, ''
233
- if n_value == CompareConst.NONE:
234
- return 0, ''
235
- if n_value == CompareConst.SHAPE_UNMATCH:
236
- return CompareConst.SHAPE_UNMATCH, ''
237
- if n_value == CompareConst.NAN:
238
- return "N/A", ''
239
-
240
- if relative_err is None:
241
- relative_err = get_relative_err(n_value, b_value)
240
+ def apply(self, n_value, b_value, relative_err):
242
241
  max_relative_err = np.max(np.abs(relative_err))
243
242
  if np.isnan(max_relative_err):
244
- message = 'Cannot compare by MaxRelativeError, the data contains nan in dump data.'
245
- return CompareConst.NAN, message
246
- return format_value(max_relative_err), ''
247
-
248
-
249
- class GetThousandErrRatio(TensorComparisonBasic):
250
- """计算相对误差小于千分之一的比例"""
251
- def apply(self, n_value, b_value, error_flag, relative_err=None):
252
- if error_flag:
253
- if n_value == CompareConst.READ_NONE:
254
- return CompareConst.NONE, ""
255
- if n_value == CompareConst.NONE:
256
- return 0, ""
257
- if n_value == CompareConst.SHAPE_UNMATCH:
258
- return CompareConst.SHAPE_UNMATCH, ""
259
- if n_value == CompareConst.NAN:
260
- return "N/A", ""
243
+ msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
244
+ return CompareConst.NAN, msg
245
+ return format_value(max_relative_err), ""
261
246
 
262
- if not n_value.shape:
263
- return CompareConst.NAN, ""
264
- if relative_err is None:
265
- relative_err = get_relative_err(n_value, b_value)
266
- if not np.size(relative_err):
267
- return CompareConst.NAN, ""
268
- return format_value(np.sum(relative_err < CompareConst.THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
269
-
270
-
271
- class GetFiveThousandErrRatio(TensorComparisonBasic):
272
- """计算相对误差小于千分之五的比例"""
273
- def apply(self, n_value, b_value, error_flag, relative_err=None):
274
- if error_flag:
275
- if n_value == CompareConst.READ_NONE:
276
- return CompareConst.NONE, ""
277
- if n_value == CompareConst.NONE:
278
- return 0, ""
279
- if n_value == CompareConst.SHAPE_UNMATCH:
280
- return CompareConst.SHAPE_UNMATCH, ""
281
- if n_value == CompareConst.NAN:
282
- return "N/A", ""
283
247
 
248
+ class GetErrRatio(TensorComparisonBasic):
249
+ """计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
250
+ def __init__(self, threshold):
251
+ self.threshold = threshold
252
+
253
+ def apply(self, n_value, b_value, relative_err):
284
254
  if not n_value.shape:
285
- return CompareConst.NAN, ""
286
- if relative_err is None:
287
- relative_err = get_relative_err(n_value, b_value)
255
+ return CompareConst.UNSUPPORTED, ""
256
+
288
257
  if not np.size(relative_err):
289
258
  return CompareConst.NAN, ""
290
- return format_value(
291
- np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
259
+
260
+ ratio = np.sum(relative_err < self.threshold) / np.size(relative_err)
261
+ return format_value(ratio), ""
292
262
 
293
263
 
294
264
  class CompareOps:
@@ -296,15 +266,36 @@ class CompareOps:
296
266
  "cosine_similarity": GetCosineSimilarity(),
297
267
  "max_abs_error": GetMaxAbsErr(),
298
268
  "max_relative_error": GetMaxRelativeErr(),
299
- "one_thousand_err_ratio": GetThousandErrRatio(),
300
- "five_thousand_err_ratio": GetFiveThousandErrRatio()
269
+ "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
270
+ "five_thousand_err_ratio": GetErrRatio(CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
301
271
  }
302
272
 
303
273
 
304
- def compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=None):
274
+ def error_value_process(n_value):
275
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
276
+ return CompareConst.UNSUPPORTED, ""
277
+ if n_value == CompareConst.NONE:
278
+ return 0, ""
279
+ if n_value == CompareConst.SHAPE_UNMATCH:
280
+ return CompareConst.SHAPE_UNMATCH, ""
281
+ if n_value == CompareConst.NAN:
282
+ return CompareConst.N_A, ""
283
+ return CompareConst.N_A, ""
284
+
285
+
286
+ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
305
287
  result_list = []
288
+ if error_flag:
289
+ result, msg = error_value_process(n_value)
290
+ result_list = [result] * len(CompareOps.compare_ops)
291
+ err_msg += msg * len(CompareOps.compare_ops)
292
+ return result_list, err_msg
293
+
294
+ relative_err = get_relative_err(n_value, b_value)
295
+ n_value, b_value = reshape_value(n_value, b_value)
296
+
306
297
  for op in CompareOps.compare_ops.values():
307
- result, msg = op.apply(n_value, b_value, error_flag, relative_err=relative_err)
308
- err_msg += msg
298
+ result, msg = op.apply(n_value, b_value, relative_err)
309
299
  result_list.append(result)
300
+ err_msg += msg
310
301
  return result_list, err_msg