mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +14 -19
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +155 -6
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +3 -0
  10. msprobe/core/common/utils.py +28 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +380 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/multiprocessing_compute.py +2 -2
  22. msprobe/core/compare/npy_compare.py +109 -147
  23. msprobe/core/compare/utils.py +189 -69
  24. msprobe/core/data_dump/data_collector.py +51 -21
  25. msprobe/core/data_dump/data_processor/base.py +38 -20
  26. msprobe/core/data_dump/data_processor/factory.py +5 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
  29. msprobe/core/data_dump/json_writer.py +29 -1
  30. msprobe/core/data_dump/scope.py +19 -18
  31. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  32. msprobe/core/overflow_check/checker.py +1 -1
  33. msprobe/core/overflow_check/utils.py +1 -1
  34. msprobe/docs/01.installation.md +96 -17
  35. msprobe/docs/02.config_introduction.md +5 -5
  36. msprobe/docs/05.data_dump_PyTorch.md +91 -61
  37. msprobe/docs/06.data_dump_MindSpore.md +57 -19
  38. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  39. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
  40. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  41. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  42. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  43. msprobe/docs/19.monitor.md +120 -27
  44. msprobe/docs/21.visualization_PyTorch.md +115 -35
  45. msprobe/docs/22.visualization_MindSpore.md +138 -41
  46. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  47. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  48. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  49. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  50. msprobe/docs/27.dump_json_instruction.md +521 -0
  51. msprobe/docs/FAQ.md +26 -2
  52. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  53. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  54. msprobe/docs/img/merge_result.png +0 -0
  55. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  56. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  57. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  58. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  59. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  60. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  61. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  62. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  63. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  64. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  65. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  66. msprobe/docs/visualization/GPTModel.png +0 -0
  67. msprobe/docs/visualization/ParallelMLP.png +0 -0
  68. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  69. msprobe/docs/visualization/mapping.png +0 -0
  70. msprobe/docs/visualization/mapping1.png +0 -0
  71. msprobe/docs/visualization/module_name.png +0 -0
  72. msprobe/docs/visualization/module_name1.png +0 -0
  73. msprobe/docs/visualization/no_mapping.png +0 -0
  74. msprobe/docs/visualization/no_mapping1.png +0 -0
  75. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  76. msprobe/docs/visualization/top_layer.png +0 -0
  77. msprobe/mindspore/__init__.py +10 -0
  78. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
  79. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  80. msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
  81. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  82. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  83. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  84. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  85. msprobe/mindspore/code_mapping/bind.py +264 -0
  86. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  87. msprobe/mindspore/code_mapping/graph.py +49 -0
  88. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  89. msprobe/mindspore/code_mapping/main.py +24 -0
  90. msprobe/mindspore/code_mapping/processor.py +34 -0
  91. msprobe/mindspore/common/const.py +3 -1
  92. msprobe/mindspore/common/utils.py +50 -5
  93. msprobe/mindspore/compare/distributed_compare.py +0 -2
  94. msprobe/mindspore/compare/ms_compare.py +105 -63
  95. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  96. msprobe/mindspore/debugger/debugger_config.py +3 -0
  97. msprobe/mindspore/debugger/precision_debugger.py +81 -12
  98. msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
  99. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  100. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  101. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  102. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  103. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  104. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  105. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  106. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  107. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  108. msprobe/mindspore/grad_probe/hook.py +13 -4
  109. msprobe/mindspore/mindtorch/__init__.py +18 -0
  110. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  111. msprobe/mindspore/ms_config.py +5 -1
  112. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  113. msprobe/mindspore/service.py +267 -101
  114. msprobe/msprobe.py +24 -3
  115. msprobe/pytorch/__init__.py +7 -6
  116. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  117. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  123. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  124. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
  125. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  126. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  127. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  128. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  129. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  130. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  131. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  132. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  133. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  134. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  135. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  136. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  140. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  141. msprobe/pytorch/common/parse_json.py +2 -1
  142. msprobe/pytorch/common/utils.py +45 -2
  143. msprobe/pytorch/compare/distributed_compare.py +17 -29
  144. msprobe/pytorch/compare/pt_compare.py +40 -20
  145. msprobe/pytorch/debugger/debugger_config.py +27 -12
  146. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  147. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  148. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  149. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
  150. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  151. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  152. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  153. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  154. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  155. msprobe/pytorch/hook_module/__init__.py +1 -1
  156. msprobe/pytorch/hook_module/hook_module.py +14 -11
  157. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  158. msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
  159. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  160. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  161. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  162. msprobe/pytorch/monitor/anomaly_detect.py +107 -22
  163. msprobe/pytorch/monitor/csv2tb.py +166 -0
  164. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  165. msprobe/pytorch/monitor/features.py +3 -3
  166. msprobe/pytorch/monitor/module_hook.py +483 -277
  167. msprobe/pytorch/monitor/module_metric.py +27 -48
  168. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  169. msprobe/pytorch/monitor/optimizer_collect.py +52 -14
  170. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  171. msprobe/pytorch/monitor/utils.py +77 -6
  172. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  173. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  174. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  175. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  176. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  177. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  178. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  179. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  180. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  181. msprobe/pytorch/service.py +176 -106
  182. msprobe/visualization/builder/graph_builder.py +62 -5
  183. msprobe/visualization/builder/msprobe_adapter.py +24 -2
  184. msprobe/visualization/compare/graph_comparator.py +64 -14
  185. msprobe/visualization/compare/mode_adapter.py +1 -15
  186. msprobe/visualization/graph/base_node.py +12 -17
  187. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  188. msprobe/visualization/graph/graph.py +9 -0
  189. msprobe/visualization/graph_service.py +97 -23
  190. msprobe/visualization/utils.py +14 -29
  191. msprobe/pytorch/functional/module_dump.py +0 -84
  192. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  193. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
  194. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
  195. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  196. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  197. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -72,38 +72,53 @@ def check_need_convert(api_name):
72
72
  return convert_type
73
73
 
74
74
 
75
- def api_info_preprocess(api_name, api_info_dict):
75
+ def cross_entropy_process(api_info_dict):
76
76
  """
77
77
  Function Description:
78
- Preprocesses the API information.
78
+ Preprocesses the cross_entropy API information.
79
79
  Parameter:
80
- api_name: Name of the API.
81
80
  api_info_dict: argument of the API.
82
81
  Return api_info_dict:
83
- convert_type: Type of conversion.
84
82
  api_info_dict: Processed argument of the API.
85
83
  """
86
- convert_type = check_need_convert(api_name)
87
- if api_name == 'cross_entropy':
88
- api_info_dict = cross_entropy_process(api_info_dict)
89
- return convert_type, api_info_dict
84
+ if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
85
+ and 'Min' in api_info_dict['input_args'][1]:
86
+ if api_info_dict['input_args'][1]['Min'] <= 0:
87
+ # The second argument in cross_entropy should be -100 or not less than 0
88
+ api_info_dict['input_args'][1]['Min'] = 0
89
+ return api_info_dict
90
90
 
91
91
 
92
- def cross_entropy_process(api_info_dict):
92
+ def histc_process(api_info_dict):
93
+ input_args = api_info_dict['input_args']
94
+ if input_args and input_args[0].get('dtype'):
95
+ dtype = input_args[0]['dtype']
96
+ if dtype in Const.TORCH_INT_DTYPE:
97
+ api_info_dict['input_args'][0]['dtype'] = Const.TORCH_FLOAT32
98
+ return api_info_dict
99
+
100
+
101
+ API_PROCESS_MAP = {
102
+ 'cross_entropy': cross_entropy_process,
103
+ 'histc': histc_process
104
+ }
105
+
106
+
107
+ def api_info_preprocess(api_name, api_info_dict):
93
108
  """
94
109
  Function Description:
95
- Preprocesses the cross_entropy API information.
110
+ Preprocesses the API information.
96
111
  Parameter:
112
+ api_name: Name of the API.
97
113
  api_info_dict: argument of the API.
98
114
  Return api_info_dict:
115
+ convert_type: Type of conversion.
99
116
  api_info_dict: Processed argument of the API.
100
117
  """
101
- if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
102
- and 'Min' in api_info_dict['input_args'][1]:
103
- if api_info_dict['input_args'][1]['Min'] <= 0:
104
- # The second argument in cross_entropy should be -100 or not less than 0
105
- api_info_dict['input_args'][1]['Min'] = 0
106
- return api_info_dict
118
+ convert_type = check_need_convert(api_name)
119
+ if api_name in API_PROCESS_MAP:
120
+ api_info_dict = API_PROCESS_MAP[api_name](api_info_dict)
121
+ return convert_type, api_info_dict
107
122
 
108
123
 
109
124
  def initialize_save_path(save_path, dir_name):
@@ -16,10 +16,12 @@
16
16
  # limitations under the License.
17
17
 
18
18
  # 定义比对算法及比对标准
19
+ import math
19
20
  import torch
20
21
  import numpy as np
21
22
 
22
23
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
24
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
25
  from msprobe.core.common.const import CompareConst
24
26
 
25
27
 
@@ -179,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
179
181
 
180
182
  def check_small_value(abs_err, small_value_mask, small_value_atol):
181
183
  '''
182
- 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
184
+ 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
183
185
  输入:
184
- rel_err:npu输出和golden输出的相对误差
186
+ abs_err:npu输出和golden输出的绝对误差
185
187
  normal_value_mask:npu输出和golden输出的正常值mask
186
- rtol:相对误差的阈值
188
+ atol:绝对误差的阈值
187
189
  输出:
188
- rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
190
+ abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
189
191
  '''
190
192
  greater_mask = np.greater(abs_err, small_value_atol)
191
193
  err_mask = np.logical_and(greater_mask, small_value_mask)
@@ -195,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol):
195
197
 
196
198
  def check_norm_value(normal_value_mask, rel_err, rtol):
197
199
  '''
198
- 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值
200
+ 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值
199
201
  输入:
200
- abs_err:npu输出和golden输出的绝对误差
202
+ rel_err:npu输出和golden输出的相对误差
201
203
  normal_value_mask:npu输出和golden输出的正常值mask
202
- atol:绝对误差的阈值
204
+ rtol:相对误差的阈值
203
205
  输出:
204
- abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例
206
+ rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例
205
207
  '''
206
208
  err_mask = np.greater(rel_err, rtol)
207
209
  err_mask = np.logical_and(err_mask, normal_value_mask)
@@ -228,3 +230,34 @@ def get_ulp_err(bench_output, device_output, dtype):
228
230
  def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
229
231
  return (device_output.astype(data_type) - bench_output).astype(data_type) * \
230
232
  np.exp2(-eb + exponent_num).astype(data_type)
233
+
234
+
235
+ def calc_ratio(x, y, dtype):
236
+ """
237
+ Calculate the ratio between NPU and GPU statistical values.
238
+
239
+ Args:
240
+ x (float): Statistical value from the NPU side
241
+ y (float): Statistical value from the GPU side
242
+ dtype: Data type used to determine the minimum error value
243
+
244
+ Returns:
245
+ float: The ratio of NPU to GPU statistical values
246
+
247
+ Notes:
248
+ - Takes absolute values of both x and y for calculation
249
+ - Uses StandardConfig.get_minmum_err(dtype) to get minimum error for the specified dtype
250
+ - Prevents division by zero by ensuring denominator is not less than minimum error
251
+ - Returns |x| / max(|y|, minimum_error)
252
+ """
253
+ x, y = abs(x), abs(y)
254
+ minmum_err = StandardConfig.get_minmum_err(dtype)
255
+ err_y = max(y, minmum_err)
256
+ return x / err_y
257
+
258
+
259
+ def compare_bool_tensor(bench_output, device_output):
260
+ error_nums = (bench_output != device_output).sum()
261
+ error_rate = float(error_nums / bench_output.size)
262
+ result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
263
+ return error_rate, result, ""
@@ -29,11 +29,15 @@ from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
29
29
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
30
30
  API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
31
31
  ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
32
- BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
33
- check_inf_or_nan
32
+ BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage
33
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_input import PrecisionCompareInput
34
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry
35
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpPrecisionCompare
36
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkPrecisionCompare
37
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
34
38
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
35
39
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
36
- from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments
40
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments
37
41
  from msprobe.core.common.file_utils import FileChecker, change_mode, create_directory
38
42
  from msprobe.pytorch.common.log import logger
39
43
  from msprobe.core.common.utils import CompareException
@@ -47,30 +51,6 @@ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_va
47
51
  'eb_inf_nan_consistency'])
48
52
  UNSUPPORTED_MESSAGE = 'This data type does not support benchmark compare.'
49
53
 
50
- DEFAULT_THRESHOLD = 1
51
-
52
- benchmark_algorithms_thresholds = {
53
- 'small_value': {
54
- 'error_threshold': 2,
55
- 'warning_threshold': 1
56
- },
57
- 'rmse': {
58
- 'error_threshold': 2,
59
- 'warning_threshold': 1
60
- },
61
- 'max_rel_err': {
62
- 'error_threshold': 10,
63
- 'warning_threshold': 1
64
- },
65
- 'mean_rel_err': {
66
- 'error_threshold': 2,
67
- 'warning_threshold': 1
68
- },
69
- 'eb': {
70
- 'error_threshold': 2,
71
- 'warning_threshold': 1
72
- }
73
- }
74
54
 
75
55
  benchmark_message = {
76
56
  "small_value_err_status": {
@@ -92,189 +72,6 @@ benchmark_message = {
92
72
  }
93
73
 
94
74
 
95
- class Standard:
96
- @staticmethod
97
- def _calc_ratio(column_name, x, y, default_value):
98
- '''
99
- 计算npu侧和gpu侧统计量的比值
100
- 输入:
101
- column_name:统计量名称
102
- x:npu侧统计量
103
- y:gpu侧统计量
104
- default:当x不接近0,y接近0,设置的比值默认值
105
- 输出:
106
- ratio:统计量x和y的比值
107
- inf_nan_consistency:不出现inf或nan时为True,出现inf或nan时必须同时为inf或-inf或nan才为True,否则为False
108
- message:当出现inf或nan时的提示信息
109
- '''
110
- x, y = convert_str_to_float(x), convert_str_to_float(y)
111
-
112
- if is_inf_or_nan(x) or is_inf_or_nan(y):
113
- return check_inf_or_nan(x, y, column_name)
114
-
115
- inf_nan_consistency = True
116
- message = ""
117
- if math.isclose(y, 0.0):
118
- if math.isclose(x, 0.0):
119
- return 1.0, inf_nan_consistency, message
120
- else:
121
- return default_value, inf_nan_consistency, message
122
- else:
123
- return abs(x / y), inf_nan_consistency, message
124
-
125
-
126
- class BenchmarkStandard(Standard):
127
- def __init__(self, api_name, npu_precision, gpu_precision):
128
- self.api_name = api_name
129
- self.npu_precision = npu_precision
130
- self.gpu_precision = gpu_precision
131
- self.small_value_err_ratio = 1
132
- self.rmse_ratio = 1
133
- self.max_rel_err_ratio = 1
134
- self.mean_rel_err_ratio = 1
135
- self.eb_ratio = 1
136
- self.small_value_err_status = CompareConst.PASS
137
- self.rmse_status = CompareConst.PASS
138
- self.max_rel_err_status = CompareConst.PASS
139
- self.mean_rel_err_status = CompareConst.PASS
140
- self.eb_status = CompareConst.PASS
141
- self.check_result_list = []
142
- self.final_result = CompareConst.PASS
143
- self.compare_message = ""
144
-
145
- def __str__(self):
146
- return "%s" % (self.api_name)
147
-
148
- @staticmethod
149
- def _get_status(ratio, algorithm):
150
- if math.isnan(ratio) or math.isinf(ratio):
151
- return CompareConst.PASS
152
- error_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('error_threshold', DEFAULT_THRESHOLD)
153
- warning_threshold = benchmark_algorithms_thresholds.get(algorithm, {}).get('warning_threshold',
154
- DEFAULT_THRESHOLD)
155
- if ratio > error_threshold:
156
- return CompareConst.ERROR
157
- elif ratio > warning_threshold:
158
- return CompareConst.WARNING
159
- return CompareConst.PASS
160
-
161
- def get_result(self):
162
- inf_nan_consistency = self._compare_ratio()
163
- small_value_inf_nan_consistency = inf_nan_consistency.small_value_inf_nan_consistency
164
- rmse_inf_nan_consistency = inf_nan_consistency.rmse_inf_nan_consistency
165
- max_rel_inf_nan_consistency = inf_nan_consistency.max_rel_inf_nan_consistency
166
- mean_rel_inf_nan_consistency = inf_nan_consistency.mean_rel_inf_nan_consistency
167
- eb_inf_nan_consistency = inf_nan_consistency.eb_inf_nan_consistency
168
- self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') if \
169
- small_value_inf_nan_consistency else CompareConst.ERROR
170
- self.check_result_list.append(self.small_value_err_status)
171
- self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') if rmse_inf_nan_consistency \
172
- else CompareConst.ERROR
173
- self.check_result_list.append(self.rmse_status)
174
- self.max_rel_err_status = self._get_status(
175
- self.max_rel_err_ratio, 'max_rel_err') if max_rel_inf_nan_consistency else CompareConst.ERROR
176
- self.check_result_list.append(self.max_rel_err_status)
177
- self.mean_rel_err_status = self._get_status(
178
- self.mean_rel_err_ratio, 'mean_rel_err') if mean_rel_inf_nan_consistency else CompareConst.ERROR
179
- self.check_result_list.append(self.mean_rel_err_status)
180
- self.eb_status = self._get_status(self.eb_ratio, 'eb')
181
- if CompareConst.ERROR in self.check_result_list:
182
- self.final_result = CompareConst.ERROR
183
- elif CompareConst.WARNING in self.check_result_list:
184
- self.final_result = CompareConst.WARNING
185
-
186
- def to_column_value(self):
187
- return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio,
188
- self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio,
189
- self.mean_rel_err_status, self.eb_ratio, self.eb_status]
190
-
191
- def _compare_ratio(self):
192
-
193
- self.small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = self._calc_ratio(
194
- ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE,
195
- self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE),
196
- self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0)
197
- self.compare_message += small_value_message
198
- self.rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE,
199
- self.npu_precision.get(ApiPrecisionCompareColumn.RMSE),
200
- self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0)
201
- self.compare_message += rmse_message
202
- self.max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = self._calc_ratio(
203
- ApiPrecisionCompareColumn.MAX_REL_ERR,
204
- self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR),
205
- self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0)
206
- self.compare_message += max_rel_message
207
- self.mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = self._calc_ratio(
208
- ApiPrecisionCompareColumn.MEAN_REL_ERR,
209
- self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR),
210
- self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0)
211
- self.compare_message += mean_rel_message
212
- self.eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB,
213
- self.npu_precision.get(ApiPrecisionCompareColumn.EB),
214
- self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0)
215
- self.compare_message += eb_message
216
-
217
- return BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
218
- max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
219
- eb_inf_nan_consistency)
220
-
221
-
222
- class ULPStandard(Standard):
223
- def __init__(self, api_name, npu_precision, gpu_precision):
224
- self.api_name = api_name
225
- self.npu_precision = npu_precision
226
- self.gpu_precision = gpu_precision
227
- self.mean_ulp_err = 0
228
- self.ulp_err_proportion = 0
229
- self.ulp_err_proportion_ratio = 1
230
- self.ulp_err_status = CompareConst.PASS
231
- self.compare_message = ""
232
-
233
- def __str__(self):
234
- return f"{self.api_name}"
235
-
236
- def get_result(self):
237
- self.mean_ulp_err = convert_str_to_float(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
238
- gpu_mean_ulp_err = convert_str_to_float(self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_ULP_ERR))
239
- inf_nan_consistency = True
240
- if is_inf_or_nan(self.mean_ulp_err) or is_inf_or_nan(gpu_mean_ulp_err):
241
- _, inf_nan_consistency, message = check_inf_or_nan(self.mean_ulp_err, gpu_mean_ulp_err,
242
- ApiPrecisionCompareColumn.MEAN_ULP_ERR)
243
- self.compare_message += message
244
- self.ulp_err_proportion = convert_str_to_float(
245
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION))
246
- self.ulp_err_proportion_ratio, ulp_inf_nan_consistency, message = self._calc_ratio(
247
- ApiPrecisionCompareColumn.ULP_ERR_PROPORTION,
248
- self.npu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION),
249
- self.gpu_precision.get(ApiPrecisionCompareColumn.ULP_ERR_PROPORTION), 10000.0)
250
- inf_nan_consistency = inf_nan_consistency and ulp_inf_nan_consistency
251
- self.compare_message += message
252
- if inf_nan_consistency:
253
- self.ulp_err_status = self._get_ulp_status(self.npu_precision.get(ApiPrecisionCompareColumn.DEVICE_DTYPE))
254
- else:
255
- self.ulp_err_status = CompareConst.ERROR
256
-
257
- def _get_ulp_status(self, dtype):
258
- if dtype == torch.float32:
259
- if self.mean_ulp_err < 64:
260
- return CompareConst.PASS
261
- elif self.ulp_err_proportion < 0.05:
262
- return CompareConst.PASS
263
- elif self.ulp_err_proportion_ratio < 1:
264
- return CompareConst.PASS
265
- else:
266
- self.compare_message += "ERROR: ULP误差不满足标准\n"
267
- return CompareConst.ERROR
268
- else:
269
- if self.ulp_err_proportion < 0.001:
270
- return CompareConst.PASS
271
- elif self.ulp_err_proportion_ratio < 1:
272
- return CompareConst.PASS
273
- else:
274
- self.compare_message += "ERROR: ULP误差不满足标准\n"
275
- return CompareConst.ERROR
276
-
277
-
278
75
  def write_detail_csv(content, save_path):
279
76
  rows = []
280
77
  content = ["{:.{}f}".format(item, msCheckerConfig.precision) \
@@ -283,6 +80,17 @@ def write_detail_csv(content, save_path):
283
80
  write_csv(rows, save_path)
284
81
 
285
82
 
83
+ def register_compare_func():
84
+ registry = StandardRegistry()
85
+ registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result)
86
+ registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result)
87
+ registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result)
88
+ registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result)
89
+ registry.register(CompareConst.BENCHMARK, record_benchmark_compare_result)
90
+ registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result)
91
+ return registry
92
+
93
+
286
94
  def api_precision_compare(config):
287
95
  logger.info("Start compare task")
288
96
  logger.info(f"Compare task result will be saved in {config.result_csv_path}")
@@ -337,6 +145,8 @@ def analyse_csv(npu_data, gpu_data, config):
337
145
  forward_status, backward_status = [], []
338
146
  last_api_name, last_api_dtype, last_api_full_name = None, None, None
339
147
  last_api_skip_message = ''
148
+ registry = register_compare_func()
149
+
340
150
  for _, row_npu in npu_data.iterrows():
341
151
  message = ''
342
152
  compare_column = ApiPrecisionOutputColumn()
@@ -362,7 +172,7 @@ def analyse_csv(npu_data, gpu_data, config):
362
172
  row_gpu = row_gpu.iloc[0]
363
173
  new_status = CompareConst.SPACE
364
174
  try:
365
- new_status = get_api_status(row_npu, row_gpu, api_name, compare_column)
175
+ new_status = get_api_status(row_npu, row_gpu, api_name, compare_column, registry)
366
176
  except Exception as err:
367
177
  logger.error(f"Get api status error: {str(err)}")
368
178
  compare_column.api_name = full_api_name_with_direction_status
@@ -383,7 +193,8 @@ def analyse_csv(npu_data, gpu_data, config):
383
193
  else:
384
194
  forward_result = get_api_checker_result(forward_status)
385
195
  backward_result = get_api_checker_result(backward_status)
386
- message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
196
+ _, base_api_name = extract_basic_api_segments(last_api_name)
197
+ message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
387
198
  message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
388
199
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
389
200
  print_test_success(last_api_name, forward_result, backward_result)
@@ -415,37 +226,30 @@ def analyse_csv(npu_data, gpu_data, config):
415
226
  else:
416
227
  forward_result = get_api_checker_result(forward_status)
417
228
  backward_result = get_api_checker_result(backward_status)
418
- message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
229
+ _, base_api_name = extract_basic_api_segments(last_api_name)
230
+ message += CompareMessage.get(base_api_name, "") if forward_result == CompareConst.ERROR else ""
419
231
  message += last_api_skip_message if forward_result == CompareConst.SKIP else ""
420
232
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
421
233
  print_test_success(last_api_name, forward_result, backward_result)
422
234
  last_api_skip_message = ''
423
235
 
424
236
 
425
def get_api_status(row_npu, row_gpu, api_name, compare_column, registry):
    """Populate ``compare_column`` for one API row and return its status.

    Args:
        row_npu: NPU result row, indexed with ``ApiPrecisionCompareColumn`` keys.
        row_gpu: GPU result row, forwarded to the comparison input.
        api_name: API name used to pick the comparison function.
        compare_column: Output record; updated in place.
        registry: Object providing ``get_comparison_function(api_name, dtype)``.

    Returns:
        ``CompareConst.SKIP`` when the row is skipped, otherwise the value
        returned by the comparison function selected from ``registry``.
    """
    full_name = row_npu[ApiPrecisionCompareColumn.API_NAME]
    device_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]

    # Skip the comparison when the API has no output (e.g. requires_grad=False
    # in the backward pass, leaving a blank dtype), when the dtype is in the
    # unsupported list, or when the recorded shape is the zero shape.
    should_skip = (
        device_dtype.isspace()
        or device_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST
        or row_npu[ApiPrecisionCompareColumn.SHAPE] == CompareConst.ZERO_SHAPE
    )

    compare_column.api_name = full_name
    if should_skip:
        compare_column.compare_result = CompareConst.SKIP
        compare_column.compare_message = row_npu[ApiPrecisionCompareColumn.MESSAGE]
        return CompareConst.SKIP

    input_data = PrecisionCompareInput(row_npu, row_gpu, device_dtype, compare_column)
    comparison_func = registry.get_comparison_function(api_name, device_dtype)
    return comparison_func(input_data)
450
254
 
451
255
 
@@ -505,21 +309,24 @@ def check_csv_columns(columns, csv_type):
505
309
  raise CompareException(CompareException.INVALID_DATA_ERROR, msg)
506
310
 
507
311
 
508
def record_binary_consistency_result(input_data):
    """Record the binary-consistency comparison result for one API row.

    Args:
        input_data: Comparison input exposing ``row_npu`` and ``compare_column``.

    Returns:
        The status returned by ``check_error_rate`` for the row's error rate.
    """
    column = input_data.compare_column
    error_rate = input_data.row_npu[ApiPrecisionCompareColumn.ERROR_RATE]
    status = check_error_rate(error_rate)

    column.error_rate = error_rate
    column.error_rate_status = status
    column.compare_result = status
    column.compare_algorithm = CompareConst.BINARY_CONSISTENCY_ALGORITHM_NAME

    # Only an ERROR status carries a message; otherwise the message is empty.
    if column.error_rate_status == CompareConst.ERROR:
        column.compare_message = "ERROR: 二进制一致错误率超过阈值\n"
    else:
        column.compare_message = ''
    return status
520
325
 
521
326
 
522
- def record_absolute_threshold_result(compare_column, row_npu):
327
+ def record_absolute_threshold_result(input_data):
328
+ row_npu = input_data.row_npu
329
+ compare_column = input_data.compare_column
523
330
  absolute_threshold_result = get_absolute_threshold_result(row_npu)
524
331
  compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
525
332
  compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
@@ -528,62 +335,88 @@ def record_absolute_threshold_result(compare_column, row_npu):
528
335
  compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
529
336
  compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
530
337
  compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result")
531
- compare_column.compare_algorithm = "绝对阈值法"
338
+ compare_column.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD_ALGORITHM_NAME
532
339
  message = ''
533
340
  if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR:
534
- message += "ERROR: inf/nan错误率超过阈值\n"
341
+ message += "ERROR: inf/nan错误率超过阈值"
535
342
  if compare_column.rel_err_ratio_status == CompareConst.ERROR:
536
- message += "ERROR: 相对误差错误率超过阈值\n"
343
+ message += "ERROR: 相对误差错误率超过阈值"
537
344
  if compare_column.abs_err_ratio_status == CompareConst.ERROR:
538
- message += "ERROR: 绝对误差错误率超过阈值\n"
345
+ message += "ERROR: 绝对误差错误率超过阈值"
539
346
  compare_column.compare_message = message
540
347
  return compare_column.compare_result
541
348
 
542
349
 
543
def record_benchmark_compare_result(input_data):
    """Run the benchmark comparison and append the per-status messages.

    Args:
        input_data: Comparison input; its ``compare_column`` is read for the
            status attributes named in ``benchmark_message`` and has matching
            messages appended to ``compare_message``.

    Returns:
        The value returned by ``BenchmarkPrecisionCompare(input_data).compare()``.
    """
    compare_result = BenchmarkPrecisionCompare(input_data).compare()
    column = input_data.compare_column
    for status_attr, messages in benchmark_message.items():
        status_value = getattr(column, status_attr)
        if status_value not in messages:
            continue
        column.compare_message += messages[status_value]
    return compare_result
358
+
359
+
360
def record_ulp_compare_result(input_data):
    """Run the ULP comparison for one API row.

    Args:
        input_data: Comparison input passed to ``UlpPrecisionCompare``.

    Returns:
        The value returned by ``UlpPrecisionCompare(input_data).compare()``.
    """
    return UlpPrecisionCompare(input_data).compare()
563
364
 
564
365
 
565
- def record_ulp_compare_result(compare_column, us):
566
- us.get_result()
567
- compare_column.mean_ulp_err = us.mean_ulp_err
568
- compare_column.ulp_err_proportion = us.ulp_err_proportion
569
- compare_column.ulp_err_proportion_ratio = us.ulp_err_proportion_ratio
570
- compare_column.ulp_err_status = us.ulp_err_status
571
- compare_column.compare_result = us.ulp_err_status
572
- compare_column.compare_algorithm = "ULP误差比对法"
573
- compare_column.compare_message = us.compare_message
366
def record_accumulative_error_compare_result(input_data):
    """Record the accumulative-error comparison result for one API row.

    The row passes only when neither the absolute-threshold result nor the
    EB check reports ``CompareConst.ERROR``.

    Args:
        input_data: Comparison input exposing ``row_npu`` and ``compare_column``.

    Returns:
        ``CompareConst.PASS`` or ``CompareConst.ERROR``; the same value is
        stored in ``compare_column.compare_result``.
    """
    row_npu = input_data.row_npu
    compare_column = input_data.compare_column

    absolute_threshold_result = get_absolute_threshold_result(row_npu)
    threshold_result = absolute_threshold_result.get("absolute_threshold_result")
    eb, eb_result = check_eb(row_npu)
    if CompareConst.ERROR in (threshold_result, eb_result):
        accumulative_error_compare_result = CompareConst.ERROR
    else:
        accumulative_error_compare_result = CompareConst.PASS

    compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio")
    compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result")
    compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio")
    compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result")
    compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio")
    compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result")
    compare_column.eb_ratio = eb
    compare_column.eb_status = eb_result
    compare_column.compare_result = accumulative_error_compare_result
    compare_column.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME

    status_messages = (
        (compare_column.inf_nan_error_ratio_status, "ERROR: inf/nan错误率超过阈值\n"),
        (compare_column.rel_err_ratio_status, "ERROR: 相对误差错误率超过阈值\n"),
        (compare_column.abs_err_ratio_status, "ERROR: 绝对误差错误率超过阈值\n"),
        (compare_column.eb_status, "ERROR: 误差均衡性超过阈值\n"),
    )
    # Every message already ends with a newline, so join with an empty
    # separator; joining with "\n" would insert a blank line between errors.
    compare_column.compare_message = "".join(
        text for status, text in status_messages if status == CompareConst.ERROR
    )
    return compare_column.compare_result
575
397
 
576
398
 
399
def check_eb(row_npu):
    """Check the row's EB value against the accumulative-error threshold.

    Args:
        row_npu: NPU result row, indexed with ``ApiPrecisionCompareColumn`` keys.

    Returns:
        A ``(eb, status)`` tuple: the EB value converted to float, and
        ``CompareConst.PASS`` when it does not exceed the threshold that
        ``StandardConfig`` returns for the row's dtype, else
        ``CompareConst.ERROR``.
    """
    eb = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.EB])
    eb_threshold = StandardConfig.get_accumulative_error_eb_threshold(
        row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
    )
    if eb <= eb_threshold:
        return eb, CompareConst.PASS
    return eb, CompareConst.ERROR
405
+
406
+
577
407
def check_thousandth_rate(thousandth_rate):
    """Return PASS when the rate reaches ``CompareConst.THOUSANDTH_PASS_VALUE``.

    Args:
        thousandth_rate: Rate value; converted with ``convert_str_to_float``.

    Returns:
        ``CompareConst.PASS`` if the converted rate is greater than or equal
        to the pass value, otherwise ``CompareConst.ERROR``.
    """
    rate = convert_str_to_float(thousandth_rate)
    if rate >= CompareConst.THOUSANDTH_PASS_VALUE:
        return CompareConst.PASS
    return CompareConst.ERROR
579
410
 
580
411
 
581
- def record_thousandth_threshold_result(compare_column, row_npu):
412
+ def record_thousandth_threshold_result(input_data):
413
+ row_npu = input_data.row_npu
414
+ compare_column = input_data.compare_column
582
415
  new_status = check_thousandth_rate(row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH])
583
416
  compare_column.rel_err_thousandth = row_npu[ApiPrecisionCompareColumn.REL_ERR_THOUSANDTH]
584
417
  compare_column.rel_err_thousandth_status = new_status
585
418
  compare_column.compare_result = new_status
586
- compare_column.compare_algorithm = "双千指标法"
419
+ compare_column.compare_algorithm = CompareConst.THOUSANDTH_STANDARD_ALGORITHM_NAME
587
420
  message = ''
588
421
  if compare_column.rel_err_thousandth_status == CompareConst.ERROR:
589
422
  message += "ERROR: 双千指标不达标\n"
@@ -66,6 +66,7 @@ BinaryCompareStandard:
66
66
  - greater_
67
67
  - greater_equal
68
68
  - greater_equal_
69
+ - histc
69
70
  - isfinite
70
71
  - isnan
71
72
  - less
@@ -130,4 +131,6 @@ ULPStandard:
130
131
  ThousandthStandard:
131
132
  - conv1d
132
133
  - conv2d
133
-
134
+
135
+ AccumulativeErrorStandard:
136
+ - test_api