mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -16,11 +16,10 @@
16
16
  import abc
17
17
 
18
18
  import numpy as np
19
- from msprobe.core.common.utils import format_value
19
+
20
20
  from msprobe.core.common.const import Const, CompareConst
21
21
  from msprobe.core.common.log import logger
22
-
23
- from msprobe.core.common.utils import CompareException
22
+ from msprobe.core.common.utils import CompareException, format_value
24
23
 
25
24
 
26
25
  def handle_inf_nan(n_value, b_value):
@@ -53,66 +52,66 @@ def handle_inf_nan(n_value, b_value):
53
52
  return n_value, b_value
54
53
 
55
54
 
56
- def get_error_type(n_value, b_value, error_flag):
57
- """判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag"""
55
+ def get_error_flag_and_msg(n_value, b_value, error_flag=False, error_file=None):
56
+ """判断数据是否有异常并返回异常的n_value, b_value,同时返回error_flag和error_msg"""
57
+ err_msg = ""
58
58
  if error_flag:
59
- return CompareConst.READ_NONE, CompareConst.READ_NONE, True
59
+ if error_file == "no_bench_data":
60
+ err_msg = "Bench does not have data file."
61
+ elif error_file:
62
+ err_msg = f"Dump file: {error_file} not found."
63
+ else:
64
+ err_msg = CompareConst.NO_BENCH
65
+ error_flag = True
66
+ return CompareConst.READ_NONE, CompareConst.READ_NONE, error_flag, err_msg
67
+
60
68
  if n_value.size == 0: # 判断读取到的数据是否为空
61
- return CompareConst.NONE, CompareConst.NONE, True
69
+ err_msg = "This is empty data, can not compare."
70
+ error_flag = True
71
+ return CompareConst.NONE, CompareConst.NONE, error_flag, err_msg
72
+ if not n_value.shape: # 判断数据是否为0维张量
73
+ err_msg = (f"This is type of 0-d tensor, can not calculate '{CompareConst.COSINE}', "
74
+ f"'{CompareConst.ONE_THOUSANDTH_ERR_RATIO}' and '{CompareConst.FIVE_THOUSANDTHS_ERR_RATIO}'. ")
75
+ error_flag = False # 0-d tensor 最大绝对误差、最大相对误差仍然支持计算,因此error_flag设置为False,不做统一处理
76
+ return n_value, b_value, error_flag, err_msg
62
77
  if n_value.shape != b_value.shape: # 判断NPU和bench的数据结构是否一致
63
- return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, True
64
- if not n_value.shape: # 判断数据是否为标量
65
- return n_value, b_value, False
78
+ err_msg = "Shape of NPU and bench tensor do not match. Skipped."
79
+ error_flag = True
80
+ return CompareConst.SHAPE_UNMATCH, CompareConst.SHAPE_UNMATCH, error_flag, err_msg
66
81
 
67
82
  try:
68
83
  n_value, b_value = handle_inf_nan(n_value, b_value) # 判断是否有nan/inf数据
69
84
  except CompareException:
70
85
  logger.error('Numpy data is unreadable, please check!')
71
- return CompareConst.UNREADABLE, CompareConst.UNREADABLE, True
86
+ err_msg = "Data is unreadable."
87
+ error_flag = True
88
+ return CompareConst.UNREADABLE, CompareConst.UNREADABLE, error_flag, err_msg
72
89
  if n_value is CompareConst.NAN or b_value is CompareConst.NAN:
73
- return CompareConst.NAN, CompareConst.NAN, True
74
- return n_value, b_value, False
90
+ err_msg = "The position of inf or nan in NPU and bench Tensor do not match."
91
+ error_flag = True
92
+ return CompareConst.NAN, CompareConst.NAN, error_flag, err_msg
93
+
94
+ if n_value.dtype != b_value.dtype: # 判断数据的dtype是否一致
95
+ err_msg = "Dtype of NPU and bench tensor do not match."
96
+ error_flag = False
97
+ return n_value, b_value, error_flag, err_msg
98
+
99
+ return n_value, b_value, error_flag, err_msg
75
100
 
76
101
 
77
102
  def reshape_value(n_value, b_value):
78
103
  """返回reshape后的数据"""
79
- if not n_value.shape: # 判断数据是否为标量
104
+ if not n_value.shape: # 判断数据是否为0维tensor, 如果0维tensor,不会转成1维tensor,直接返回
80
105
  if n_value.dtype == bool:
81
106
  n_value = n_value.astype(float)
82
107
  b_value = b_value.astype(float)
83
108
  return n_value, b_value
84
109
 
85
- n_value = n_value.reshape(-1).astype(float)
110
+ n_value = n_value.reshape(-1).astype(float) # 32转64为了防止某些数转dataframe时出现误差
86
111
  b_value = b_value.reshape(-1).astype(float)
87
112
  return n_value, b_value
88
113
 
89
114
 
90
- def get_error_message(n_value, b_value, npu_op_name, error_flag, error_file=None):
91
- """获取异常情况的错误信息"""
92
- if error_flag:
93
- if n_value == CompareConst.READ_NONE:
94
- if error_file == 'no_bench_data':
95
- return 'Bench does not have data file.'
96
- elif error_file is not None:
97
- return "Dump file: {} not found.".format(error_file)
98
- return CompareConst.NO_BENCH
99
- if n_value == CompareConst.NONE:
100
- return "This is empty data, can not compare."
101
- if n_value == CompareConst.SHAPE_UNMATCH:
102
- return "Shape of NPU and bench Tensor do not match. Skipped."
103
- if n_value == CompareConst.NAN:
104
- return "The position of inf or nan in NPU and bench Tensor do not match."
105
- if n_value == CompareConst.UNREADABLE:
106
- return "The npy data is unable to be read or compared, please check dump data files."
107
- else:
108
- if not n_value.shape:
109
- return "This is type of scalar data, can not compare."
110
- if n_value.dtype != b_value.dtype:
111
- logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(npu_op_name))
112
- return "Dtype of NPU and bench Tensor do not match."
113
- return ""
114
-
115
-
116
115
  def npy_data_check(n_value, b_value):
117
116
  error_message = ""
118
117
  if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray):
@@ -170,10 +169,25 @@ def statistics_data_check(result_dict):
170
169
  class TensorComparisonBasic(abc.ABC):
171
170
  """NPU和bench中npy数据的比较模板"""
172
171
  @abc.abstractmethod
173
- def apply(self, n_value, b_value, error_flag, relative_err=None):
172
+ def apply(self, n_value, b_value, relative_err):
174
173
  raise NotImplementedError
175
174
 
176
175
 
176
+ def get_relative_err(n_value, b_value):
177
+ """计算相对误差"""
178
+ with np.errstate(divide='ignore', invalid='ignore'):
179
+ if b_value.dtype not in CompareConst.FLOAT_TYPE:
180
+ n_value, b_value = n_value.astype(float), b_value.astype(float)
181
+
182
+ n_value_copy = n_value.copy()
183
+ b_value_copy = b_value.copy()
184
+ zero_mask = (b_value_copy == 0)
185
+ b_value_copy[zero_mask] += Const.FLOAT_EPSILON
186
+ n_value_copy[zero_mask] += Const.FLOAT_EPSILON
187
+ relative_err = np.divide((n_value_copy - b_value_copy), b_value_copy)
188
+ return np.abs(relative_err)
189
+
190
+
177
191
  class GetCosineSimilarity(TensorComparisonBasic):
178
192
  """计算cosine相似度"""
179
193
  @staticmethod
@@ -184,140 +198,67 @@ class GetCosineSimilarity(TensorComparisonBasic):
184
198
  return round(float(result), 6)
185
199
  return result
186
200
 
187
- def apply(self, n_value, b_value, error_flag, relative_err=None):
188
- if error_flag:
189
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
190
- return CompareConst.UNSUPPORTED, ''
191
- if n_value == CompareConst.NONE:
192
- return CompareConst.UNSUPPORTED, ''
193
- if n_value == CompareConst.SHAPE_UNMATCH:
194
- return CompareConst.SHAPE_UNMATCH, ''
195
- if n_value == CompareConst.NAN:
196
- return CompareConst.N_A, ''
197
-
201
+ def apply(self, n_value, b_value, relative_err):
198
202
  if not n_value.shape:
199
- return CompareConst.UNSUPPORTED, ''
203
+ return CompareConst.UNSUPPORTED, ""
200
204
 
201
- with np.errstate(divide='ignore', invalid='ignore'):
205
+ with np.errstate(divide="ignore", invalid="ignore"):
202
206
  if len(n_value) == 1:
203
- return CompareConst.UNSUPPORTED, "This tensor is scalar."
207
+ return CompareConst.UNSUPPORTED, "This is a 1-d tensor of length 1."
204
208
  num = n_value.dot(b_value)
205
209
  a_norm = np.linalg.norm(n_value)
206
210
  b_norm = np.linalg.norm(b_value)
207
211
 
208
212
  if a_norm <= Const.FLOAT_EPSILON and b_norm <= Const.FLOAT_EPSILON:
209
- return 1.0, ''
213
+ return 1.0, ""
210
214
  if a_norm <= Const.FLOAT_EPSILON:
211
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in npu dump data.'
215
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in npu dump data."
212
216
  if b_norm <= Const.FLOAT_EPSILON:
213
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data.'
217
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, All the data is Zero in Bench dump data."
214
218
 
215
219
  cos = num / (a_norm * b_norm)
216
220
  if np.isnan(cos):
217
- return CompareConst.NAN, 'Cannot compare by Cosine Similarity, the dump data has NaN.'
221
+ return CompareConst.NAN, "Cannot compare by Cosine Similarity, the dump data has NaN."
218
222
  result = format_value(cos)
219
223
  result = self.correct_data(result)
220
- return 1.0 if float(result) > 0.99999 else result, ''
224
+ return result, ""
221
225
 
222
226
 
223
227
  class GetMaxAbsErr(TensorComparisonBasic):
224
228
  """计算最大绝对误差"""
225
- def apply(self, n_value, b_value, error_flag, relative_err=None):
226
- if error_flag:
227
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
228
- return CompareConst.UNSUPPORTED, ""
229
- if n_value == CompareConst.NONE:
230
- return 0, ""
231
- if n_value == CompareConst.SHAPE_UNMATCH:
232
- return CompareConst.SHAPE_UNMATCH, ""
233
- if n_value == CompareConst.NAN:
234
- return CompareConst.N_A, ""
235
-
229
+ def apply(self, n_value, b_value, relative_err):
236
230
  temp_res = n_value - b_value
237
231
  max_value = np.max(np.abs(temp_res))
238
232
  if np.isnan(max_value):
239
- message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
240
- return CompareConst.NAN, message
233
+ msg = "Cannot compare by MaxAbsError, the data contains nan/inf/-inf in dump data."
234
+ return CompareConst.NAN, msg
241
235
  return format_value(max_value), ""
242
236
 
243
237
 
244
- def get_relative_err(n_value, b_value):
245
- """计算相对误差"""
246
- with np.errstate(divide='ignore', invalid='ignore'):
247
- if b_value.dtype not in CompareConst.FLOAT_TYPE:
248
- n_value, b_value = n_value.astype(float), b_value.astype(float)
249
- zero_mask = (b_value == 0)
250
- b_value[zero_mask] += np.finfo(b_value.dtype).eps
251
- n_value[zero_mask] += np.finfo(b_value.dtype).eps
252
- relative_err = np.divide((n_value - b_value), b_value)
253
- return np.abs(relative_err)
254
-
255
-
256
238
  class GetMaxRelativeErr(TensorComparisonBasic):
257
239
  """计算最大相对误差"""
258
- def apply(self, n_value, b_value, error_flag, relative_err=None):
259
- if error_flag:
260
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
261
- return CompareConst.UNSUPPORTED, ''
262
- if n_value == CompareConst.NONE:
263
- return 0, ''
264
- if n_value == CompareConst.SHAPE_UNMATCH:
265
- return CompareConst.SHAPE_UNMATCH, ''
266
- if n_value == CompareConst.NAN:
267
- return CompareConst.N_A, ''
268
-
269
- if relative_err is None:
270
- relative_err = get_relative_err(n_value, b_value)
240
+ def apply(self, n_value, b_value, relative_err):
271
241
  max_relative_err = np.max(np.abs(relative_err))
272
242
  if np.isnan(max_relative_err):
273
- message = 'Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data.'
274
- return CompareConst.NAN, message
275
- return format_value(max_relative_err), ''
276
-
277
-
278
- class GetThousandErrRatio(TensorComparisonBasic):
279
- """计算相对误差小于千分之一的比例"""
280
- def apply(self, n_value, b_value, error_flag, relative_err=None):
281
- if error_flag:
282
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
283
- return CompareConst.UNSUPPORTED, ""
284
- if n_value == CompareConst.NONE:
285
- return 0, ""
286
- if n_value == CompareConst.SHAPE_UNMATCH:
287
- return CompareConst.SHAPE_UNMATCH, ""
288
- if n_value == CompareConst.NAN:
289
- return CompareConst.N_A, ""
243
+ msg = "Cannot compare by MaxRelativeError, the data contains nan/inf/-inf in dump data."
244
+ return CompareConst.NAN, msg
245
+ return format_value(max_relative_err), ""
290
246
 
291
- if not n_value.shape:
292
- return CompareConst.NAN, ""
293
- if relative_err is None:
294
- relative_err = get_relative_err(n_value, b_value)
295
- if not np.size(relative_err):
296
- return CompareConst.NAN, ""
297
- return format_value(np.sum(relative_err < CompareConst.THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
298
-
299
-
300
- class GetFiveThousandErrRatio(TensorComparisonBasic):
301
- """计算相对误差小于千分之五的比例"""
302
- def apply(self, n_value, b_value, error_flag, relative_err=None):
303
- if error_flag:
304
- if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
305
- return CompareConst.UNSUPPORTED, ""
306
- if n_value == CompareConst.NONE:
307
- return 0, ""
308
- if n_value == CompareConst.SHAPE_UNMATCH:
309
- return CompareConst.SHAPE_UNMATCH, ""
310
- if n_value == CompareConst.NAN:
311
- return CompareConst.N_A, ""
312
247
 
248
+ class GetErrRatio(TensorComparisonBasic):
249
+ """计算相对误差小于指定阈值(千分之一、千分之五)的比例"""
250
+ def __init__(self, threshold):
251
+ self.threshold = threshold
252
+
253
+ def apply(self, n_value, b_value, relative_err):
313
254
  if not n_value.shape:
314
- return CompareConst.NAN, ""
315
- if relative_err is None:
316
- relative_err = get_relative_err(n_value, b_value)
255
+ return CompareConst.UNSUPPORTED, ""
256
+
317
257
  if not np.size(relative_err):
318
258
  return CompareConst.NAN, ""
319
- return format_value(
320
- np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), ""
259
+
260
+ ratio = np.sum(relative_err < self.threshold) / np.size(relative_err)
261
+ return format_value(ratio), ""
321
262
 
322
263
 
323
264
  class CompareOps:
@@ -325,15 +266,36 @@ class CompareOps:
325
266
  "cosine_similarity": GetCosineSimilarity(),
326
267
  "max_abs_error": GetMaxAbsErr(),
327
268
  "max_relative_error": GetMaxRelativeErr(),
328
- "one_thousand_err_ratio": GetThousandErrRatio(),
329
- "five_thousand_err_ratio": GetFiveThousandErrRatio()
269
+ "one_thousand_err_ratio": GetErrRatio(CompareConst.THOUSAND_RATIO_THRESHOLD),
270
+ "five_thousand_err_ratio": GetErrRatio(CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD)
330
271
  }
331
272
 
332
273
 
333
- def compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=None):
274
+ def error_value_process(n_value):
275
+ if n_value == CompareConst.READ_NONE or n_value == CompareConst.UNREADABLE:
276
+ return CompareConst.UNSUPPORTED, ""
277
+ if n_value == CompareConst.NONE:
278
+ return 0, ""
279
+ if n_value == CompareConst.SHAPE_UNMATCH:
280
+ return CompareConst.SHAPE_UNMATCH, ""
281
+ if n_value == CompareConst.NAN:
282
+ return CompareConst.N_A, ""
283
+ return CompareConst.N_A, ""
284
+
285
+
286
+ def compare_ops_apply(n_value, b_value, error_flag, err_msg):
334
287
  result_list = []
288
+ if error_flag:
289
+ result, msg = error_value_process(n_value)
290
+ result_list = [result] * len(CompareOps.compare_ops)
291
+ err_msg += msg * len(CompareOps.compare_ops)
292
+ return result_list, err_msg
293
+
294
+ relative_err = get_relative_err(n_value, b_value)
295
+ n_value, b_value = reshape_value(n_value, b_value)
296
+
335
297
  for op in CompareOps.compare_ops.values():
336
- result, msg = op.apply(n_value, b_value, error_flag, relative_err=relative_err)
337
- err_msg += msg
298
+ result, msg = op.apply(n_value, b_value, relative_err)
338
299
  result_list.append(result)
300
+ err_msg += msg
339
301
  return result_list, err_msg