mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
21
+ check_small_value
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
24
+ from msprobe.core.common.const import CompareConst
25
+
26
+
27
+
28
class AbsolutethdCompare(BaseCompare):
    """
    Absolute threshold compare class.

    This class is used to compare the absolute threshold of benchmark outputs and device outputs.
    It calculates various metrics such as inf_nan_error_ratio, rel_err_ratio, and abs_err_ratio
    to determine the accuracy of the device output compared to the benchmark output.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        dtype (torch.dtype): The data type of the outputs.
        abs_bench (np.ndarray): The absolute value of the benchmark output.
        abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
        both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
        inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
        abs_err (np.ndarray): The absolute error between the benchmark and device outputs.
        rtol (float): The relative tolerance for comparison.
        rel_err (np.ndarray): The relative error between the benchmark and device outputs.
        small_value (float): The small value threshold for comparison.
        small_value_atol (float): The absolute tolerance for small values.
        small_value_mask (np.ndarray): A mask indicating where values are small.
        normal_value_mask (np.ndarray): A mask indicating where values are normal.

    Methods:
        _get_rtol(): Gets the relative tolerance based on the data type.
        _get_rel_err(abs_err, abs_bench_with_eps): Calculates the relative error.
        _get_normal_value_mask(both_finite_mask, small_value_mask): Gets the mask for normal values.
        _pre_compare(): Prepares the comparison by calculating various metrics.
        _compute_metrics(): Computes the comparison metrics.

    Note:
        This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
        'compare_column' and 'dtype'.
        The 'dtype' should be a PyTorch data type.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """
    def __init__(self, input_data):
        super(AbsolutethdCompare, self).__init__(input_data)
        # Tag used by StandardConfig lookups to select the thresholds of this standard.
        self.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD

    def _get_rtol(self):
        # Relative tolerance is dtype-dependent (e.g. fp16 vs fp32 thresholds).
        return StandardConfig.get_rtol(self.dtype)

    def _pre_compare(self):
        """
        Prepares the comparison by calculating various metrics.

        This method performs the following steps:
        1. Calculates the absolute benchmark values and their epsilon-adjusted versions.
        2. Determines masks for finite and infinite/NaN values in the outputs.
        3. Computes the absolute error between benchmark and device outputs.
        4. Retrieves the relative tolerance based on the data type.
        5. Calculates the relative error using the absolute error and epsilon-adjusted benchmark values.
        6. Determines the small value threshold and its absolute tolerance.
        7. Creates a mask for small values based on the benchmark values and finite mask.
        8. Creates a mask for normal values by excluding small values from the finite mask.
        """
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.rtol = self._get_rtol()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
        self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask)

    def _compute_metrics(self):
        """Compute the pass ratios for the three value regions (inf/NaN, normal, small)."""
        # Consistency of inf/NaN positions between benchmark and device outputs.
        inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype,
                                                  self.rtol)
        # Relative-error check over finite, non-small values.
        rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.rtol)
        # Absolute-error check over small values, with the dtype-specific atol.
        abs_err_ratio = check_small_value(self.abs_err, self.small_value_mask, self.small_value_atol)
        return {
            "inf_nan_error_ratio": inf_nan_error_ratio,
            "rel_err_ratio": rel_err_ratio,
            "abs_err_ratio": abs_err_ratio
        }
@@ -0,0 +1,107 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
21
+ check_small_value, get_error_balance
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
24
+ from msprobe.core.common.const import CompareConst
25
+
26
+
27
class AccumulativeErrorCompare(BaseCompare):
    """
    Accumulative error compare class.

    This class is used to compare benchmark outputs and device outputs under the
    accumulative error standard. It calculates metrics such as inf_nan_error_ratio,
    rel_err_ratio, abs_err_ratio and error balance (eb) to determine the accuracy
    of the device output compared to the benchmark output.

    Attributes:
        bench_output (np.ndarray): The output from the benchmark.
        device_output (np.ndarray): The output from the device.
        dtype (torch.dtype): The data type of the outputs.
        abs_bench (np.ndarray): The absolute value of the benchmark output.
        abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
        both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
        inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
        abs_err (np.ndarray): The absolute error between the benchmark and device outputs.
        bound (float): The tolerance for comparison.
        rel_err (np.ndarray): The relative error between the benchmark and device outputs.
        small_value (float): The small value threshold for comparison.
        small_value_atol (float): The absolute tolerance for small values.
        small_value_mask (np.ndarray): A mask indicating where values are small.
        normal_value_mask (np.ndarray): A mask indicating where values are normal.

    Methods:
        _get_bound(): Gets the accumulative error tolerance based on the data type.
        _get_rel_err(abs_err, abs_bench_with_eps): Calculates the relative error.
        _get_normal_value_mask(both_finite_mask, small_value_mask): Gets the mask for normal values.
        _pre_compare(): Prepares the comparison by calculating various metrics.
        _compute_metrics(): Computes the comparison metrics.

    Note:
        This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
        'compare_column' and 'dtype'.
        The 'dtype' should be a PyTorch data type.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """
    def __init__(self, input_data):
        super(AccumulativeErrorCompare, self).__init__(input_data)
        # Tag used by StandardConfig lookups to select the thresholds of this standard.
        self.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE

    def _get_bound(self):
        # Accumulative-error tolerance is dtype-dependent.
        return StandardConfig.get_accumulative_error_bound(self.dtype)

    def _pre_compare(self):
        """
        Prepares the comparison by calculating various metrics.

        This method performs the following steps:
        1. Calculates the absolute benchmark values and their epsilon-adjusted versions.
        2. Determines masks for finite and infinite/NaN values in the outputs.
        3. Computes the absolute error between benchmark and device outputs.
        4. Retrieves the tolerance based on the data type.
        5. Calculates the relative error using the absolute error and epsilon-adjusted benchmark values.
        6. Determines the small value threshold and its absolute tolerance.
        7. Creates a mask for small values based on the benchmark values and finite mask.
        8. Creates a mask for normal values by excluding small values from the finite mask.
        """
        self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
        self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
        self.abs_err = self.stat_abs_error()
        self.bound = self._get_bound()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)
        self.small_value, self.small_value_atol = self.get_small_value_threshold()
        self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
        self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask)

    def _compute_metrics(self):
        """Compute pass ratios for the three value regions plus the error balance."""
        inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype,
                                                  self.bound)
        rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.bound)
        # NOTE(review): unlike AbsolutethdCompare, the small-value check uses the
        # accumulative-error bound instead of small_value_atol — confirm this is
        # intended by the accumulative error standard.
        abs_err_ratio = check_small_value(self.abs_err, self.small_value_mask, self.bound)
        # Error balance: signed aggregate deviation between device and benchmark.
        eb = get_error_balance(self.bench_output, self.device_output)
        return {
            "inf_nan_error_ratio": inf_nan_error_ratio,
            "rel_err_ratio": rel_err_ratio,
            "abs_err_ratio": abs_err_ratio,
            "eb": eb
        }
+ }
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from abc import ABC, abstractmethod
19
+ import numpy as np
20
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import convert_str_to_float
21
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
22
+ get_finite_and_infinite_mask, get_small_value_mask
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
24
+
25
+
26
class BaseCompare(ABC):
    """
    Abstract base class for benchmark-vs-device output comparison.

    Holds the benchmark output, device output, result column object and dtype
    taken from the supplied input data, and drives a comparison through the
    template method ``compare()``: subclass-specific preparation
    (``_pre_compare``), metric computation (``_compute_metrics``) and result
    recording (``_post_compare``). Stateless numeric helpers are exposed as
    static methods so concrete standards can share them.

    Note:
        ``input_data`` is expected to expose ``bench_output``,
        ``device_output``, ``compare_column`` and ``dtype`` attributes.
        Subclasses must implement ``_pre_compare`` and typically override
        ``_compute_metrics``.

    See Also:
        StandardConfig: dtype-dependent threshold configuration used here.
    """
    def __init__(self, input_data):
        # Unpack the fields every concrete comparison relies on.
        self.bench_output = input_data.bench_output
        self.device_output = input_data.device_output
        self.compare_column = input_data.compare_column
        self.dtype = input_data.dtype
        self.compare_algorithm = None

    @staticmethod
    def stat_small_value_mask(abs_bench, both_finite_mask, small_value):
        """Mask of positions whose finite benchmark magnitude is below small_value."""
        return get_small_value_mask(abs_bench, both_finite_mask, small_value)

    @staticmethod
    def _get_rel_err(abs_err, abs_bench_with_eps):
        # Element-wise relative error; the epsilon-guarded denominator avoids /0.
        return abs_err / abs_bench_with_eps

    @staticmethod
    def _get_normal_value_mask(both_finite_mask, small_value_mask):
        # "Normal" values: finite and not classified as small.
        return np.logical_and(both_finite_mask, np.logical_not(small_value_mask))

    @abstractmethod
    def _pre_compare(self):
        raise NotImplementedError

    def get_small_value_threshold(self):
        """Return (small_value, small_value_atol) for this dtype and algorithm."""
        return (StandardConfig.get_small_value(self.dtype, self.compare_algorithm),
                StandardConfig.get_small_value_atol(self.dtype, self.compare_algorithm))

    def stat_abs_bench_with_eps(self):
        """Return (|bench|, |bench| adjusted with epsilon) for this dtype."""
        return get_abs_bench_with_eps(self.bench_output, self.dtype)

    def stat_abs_error(self):
        """Return the element-wise absolute error between bench and device outputs."""
        return get_abs_err(self.bench_output, self.device_output)

    def stat_finite_and_infinite_mask(self):
        """Return masks: both outputs finite / either output inf or NaN."""
        return get_finite_and_infinite_mask(self.bench_output, self.device_output)

    def compare(self):
        """Template method: prepare, compute metrics, then record them."""
        self._pre_compare()
        self._post_compare(self._compute_metrics())

    def _compute_metrics(self):
        # Default: no metrics; concrete standards override this.
        return {}

    def _post_compare(self, metrics):
        # Record whatever metrics the subclass produced.
        self.compare_column.update(metrics)
112
+
113
+
114
class BasePrecisionCompare(ABC):
    """
    Base class for precision comparison between NPU and GPU result rows.

    Subclasses implement the metric computation (``_compute_ratio``) and the
    mapping from metrics to a pass/fail status (``_get_status``); this class
    drives the flow and records the results into ``compare_column``.

    Note:
        ``input_data`` is expected to expose ``row_npu``, ``row_gpu``,
        ``dtype`` and ``compare_column`` attributes; the rows are mappings
        from column name to a (string) metric value.
    """
    # Fix: inherit ABC so @abstractmethod is actually enforced — on a plain
    # class the decorator does not prevent instantiation of an incomplete type.
    def __init__(self, input_data):
        self.row_npu = input_data.row_npu
        self.row_gpu = input_data.row_gpu
        self.dtype = input_data.dtype
        self.compare_column = input_data.compare_column
        self.compare_algorithm = None

    @abstractmethod
    def _get_status(self, metrics, inf_nan_consistency):
        pass

    @abstractmethod
    def _compute_ratio(self):
        pass

    def compare(self):
        """Compute metrics, record them, and return the overall compare result."""
        metrics, inf_nan_consistency = self._compute_ratio()
        compare_result = self._post_compare(metrics, inf_nan_consistency)
        return compare_result

    def _get_and_convert_values(self, column_name):
        """
        Fetch ``column_name`` from both rows and convert the values to float.

        Raises:
            ValueError: if either row has no value for the column.
        """
        npu_value = self.row_npu.get(column_name)
        gpu_value = self.row_gpu.get(column_name)
        if npu_value is None:
            raise ValueError(f"NPU value for column '{column_name}' is None.")
        if gpu_value is None:
            raise ValueError(f"GPU value for column '{column_name}' is None.")
        npu_value = convert_str_to_float(npu_value)
        gpu_value = convert_str_to_float(gpu_value)
        return npu_value, gpu_value

    def _post_compare(self, metrics, inf_nan_consistency):
        # Let the subclass attach a status, tag the algorithm, then record.
        metrics = self._get_status(metrics, inf_nan_consistency)
        metrics.update({'compare_algorithm': self.compare_algorithm})
        self.compare_column.update(metrics)
        compare_result = metrics.get('compare_result')
        return compare_result
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+ from collections import namedtuple
20
+ import numpy as np
21
+
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
24
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_small_value_err_ratio, get_rel_err, \
25
+ get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err
26
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
27
+ is_inf_or_nan
28
+ from msprobe.core.common.const import CompareConst
29
+
30
+
31
+ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
32
+ 'rmse_inf_nan_consistency',
33
+ 'max_rel_inf_nan_consistency',
34
+ 'mean_rel_inf_nan_consistency',
35
+ 'eb_inf_nan_consistency'])
36
+
37
+
38
+ class BenchmarkCompare(BaseCompare):
39
+ """
40
+ Benchmark comparison class for calculating accuracy metrics.
41
+
42
+ This class is designed to compare the output of a benchmark test with the output of a device.
43
+ It calculates various metrics such as small value error ratio, RMSE, error balance, max relative error,
44
+ and mean relative error to assess the accuracy of the device output against the benchmark output.
45
+
46
+ Attributes:
47
+ bench_output (np.ndarray): The output from the benchmark.
48
+ device_output (np.ndarray): The output from the device.
49
+ dtype (torch.dtype): The data type of the outputs.
50
+ abs_bench (np.ndarray): The absolute value of the benchmark output.
51
+ abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
52
+ both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
53
+ inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
54
+ abs_err (np.ndarray): The absolute error between the benchmark and device outputs.
55
+ small_value (float): The small value threshold for comparison.
56
+ small_value_atol (float): The absolute tolerance for small values.
57
+ small_value_mask (np.ndarray): A mask indicating where values are small.
58
+ rel_err (np.ndarray): The relative error between the benchmark and device outputs.
59
+ abs_err_greater_mask (np.ndarray): A mask indicating where absolute error is greater than the small value
60
+ tolerance.
61
+
62
+ Methods:
63
+ _get_abs_err_greater_mask(small_value_atol): Calculates a mask where absolute error is greater than the small
64
+ value tolerance.
65
+ _compute_rel_err(): Computes the relative error between the benchmark and device outputs.
66
+ _pre_compare(): Prepares the comparison by calculating various metrics.
67
+ _compute_metrics(): Computes the accuracy metrics.
68
+
69
+ Note:
70
+ This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
71
+ 'compare_column' and 'dtype'.
72
+ The data type should be a PyTorch data type.
73
+
74
+ See Also:
75
+ BaseCompare: The base class for comparison classes.
76
+ InputData: The class containing input data for comparison.
77
+ """
78
+
79
+ def __init__(self, input_data):
80
+ super(BenchmarkCompare, self).__init__(input_data)
81
+ self.compare_algorithm = CompareConst.BENCHMARK
82
+
83
+ def _get_abs_err_greater_mask(self, small_value_atol):
84
+ abs_err_greater_mask = np.greater(self.abs_err, small_value_atol)
85
+ return abs_err_greater_mask
86
+
87
+ def _compute_rel_err(self):
88
+ rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask)
89
+ return rel_err
90
+
91
+ def _pre_compare(self):
92
+ self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
93
+ self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
94
+ self.abs_err = self.stat_abs_error()
95
+ self.small_value, self.small_value_atol = self.get_small_value_threshold()
96
+ self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
97
+ self.rel_err = self._compute_rel_err()
98
+ self.abs_err_greater_mask = self._get_abs_err_greater_mask(self.small_value_atol)
99
+
100
+ def _compute_metrics(self):
101
+ """
102
+ Computes a comprehensive set of error metrics for the comparison between benchmark and device outputs.
103
+
104
+ This method calculates five key metrics:
105
+ 1. Small Value Error Ratio: The proportion of errors associated with small values.
106
+ 2. Root Mean Square Error (RMSE): The square root of the mean of the squared errors.
107
+ 3. Error Balance (EB): A measure of the balance between the errors in the benchmark and device outputs.
108
+ 4. Maximum Relative Error: The maximum relative error between the benchmark and device outputs.
109
+ 5. Mean Relative Error: The mean relative error between the benchmark and device outputs.
110
+
111
+ Returns:
112
+ dict: A dictionary containing the computed error metrics.
113
+ The dictionary has the following keys:
114
+ - "small_value_err_ratio": The proportion of errors associated with small values.
115
+ - "max_rel_error": The maximum relative error.
116
+ - "mean_rel_error": The mean relative error.
117
+ - "rmse": The root mean square error.
118
+ - "eb": The error balance.
119
+ """
120
+ small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, self.abs_err_greater_mask)
121
+ rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask))
122
+ eb = get_error_balance(self.bench_output, self.device_output)
123
+ max_rel_error = get_max_rel_err(self.rel_err)
124
+ mean_rel_error = get_mean_rel_err(self.rel_err)
125
+
126
+ return {
127
+ "small_value_err_ratio": small_value_err_ratio,
128
+ "max_rel_error": max_rel_error,
129
+ "mean_rel_error": mean_rel_error,
130
+ "rmse": rmse,
131
+ "eb": eb
132
+ }
133
+
134
+
135
+ class BenchmarkPrecisionCompare(BasePrecisionCompare):
136
+ def __init__(self, input_data):
137
+ super().__init__(input_data)
138
+ self.compare_algorithm = CompareConst.BENCHMARK_COMPARE_ALGORITHM_NAME
139
+
140
+ @staticmethod
141
+ def get_final_status(status_list):
142
+ compare_result = CompareConst.PASS
143
+ if CompareConst.ERROR in status_list:
144
+ compare_result = CompareConst.ERROR
145
+ elif CompareConst.WARNING in status_list:
146
+ compare_result = CompareConst.WARNING
147
+ return compare_result
148
+
149
+ def _calc_ratio(self, column_name):
150
+ npu_value, gpu_value = self._get_and_convert_values(column_name)
151
+ if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
152
+ return check_inf_or_nan(npu_value, gpu_value, column_name)
153
+ else:
154
+ return calc_ratio(npu_value, gpu_value, str(self.dtype)), True, ""
155
+
156
+ def _compute_ratio(self):
157
+ compare_message = ""
158
+ small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = \
159
+ self._calc_ratio(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE)
160
+ compare_message += small_value_message
161
+ rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE)
162
+ compare_message += rmse_message
163
+ max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = \
164
+ self._calc_ratio(ApiPrecisionCompareColumn.MAX_REL_ERR)
165
+ compare_message += max_rel_message
166
+ mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = \
167
+ self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR)
168
+ compare_message += mean_rel_message
169
+ eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB)
170
+ compare_message += eb_message
171
+
172
+ metrics = {
173
+ CompareConst.SMALL_VALUE_ERR_RATIO: small_value_err_ratio,
174
+ CompareConst.RMSE_RATIO: rmse_ratio,
175
+ CompareConst.MAX_REL_ERR_RATIO: max_rel_err_ratio,
176
+ CompareConst.MEAN_REL_ERR_RATIO: mean_rel_err_ratio,
177
+ CompareConst.EB_RATIO: eb_ratio,
178
+ CompareConst.COMPARE_MESSAGE: compare_message
179
+ }
180
+
181
+ return metrics, \
182
+ BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
183
+ max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
184
+ eb_inf_nan_consistency)
185
+
186
+ def _get_threshold(self, metric):
187
+ error_threshold = StandardConfig.get_benchmark_threshold(metric)
188
+ return error_threshold
189
+
190
+ def _get_single_metric_status(self, ratio, metric):
191
+ if is_inf_or_nan(ratio):
192
+ return CompareConst.PASS
193
+ error_threshold = self._get_threshold(metric)
194
+ if ratio > error_threshold:
195
+ return CompareConst.ERROR
196
+ return CompareConst.PASS
197
+
198
+ def _get_status(self, metrics, inf_nan_consistency):
199
+ small_value_err_ratio = metrics.get(CompareConst.SMALL_VALUE_ERR_RATIO)
200
+ rmse_ratio = metrics.get(CompareConst.RMSE_RATIO)
201
+ max_rel_err_ratio = metrics.get(CompareConst.MAX_REL_ERR_RATIO)
202
+ mean_rel_err_ratio = metrics.get(CompareConst.MEAN_REL_ERR_RATIO)
203
+ eb_ratio = metrics.get(CompareConst.EB_RATIO)
204
+
205
+ small_value_err_status = self._get_single_metric_status(small_value_err_ratio, CompareConst.SMALL_VALUE) \
206
+ if inf_nan_consistency.small_value_inf_nan_consistency else CompareConst.ERROR
207
+ rmse_status = self._get_single_metric_status(rmse_ratio, CompareConst.RMSE) \
208
+ if inf_nan_consistency.rmse_inf_nan_consistency else CompareConst.ERROR
209
+ max_rel_err_status = self._get_single_metric_status(max_rel_err_ratio, CompareConst.MAX_REL_ERR) \
210
+ if inf_nan_consistency.max_rel_inf_nan_consistency else CompareConst.ERROR
211
+ mean_rel_err_status = self._get_single_metric_status(mean_rel_err_ratio, CompareConst.MEAN_REL_ERR) \
212
+ if inf_nan_consistency.mean_rel_inf_nan_consistency else CompareConst.ERROR
213
+ eb_status = self._get_single_metric_status(eb_ratio, CompareConst.EB) \
214
+ if inf_nan_consistency.eb_inf_nan_consistency else CompareConst.ERROR
215
+ status_list = [small_value_err_status, rmse_status, max_rel_err_status, mean_rel_err_status]
216
+ compare_result = self.get_final_status(status_list)
217
+ status_dict = {
218
+ CompareConst.SMALL_VALUE_ERR_STATUS: small_value_err_status,
219
+ CompareConst.RMSE_STATUS: rmse_status,
220
+ CompareConst.MAX_REL_ERR_STATUS: max_rel_err_status,
221
+ CompareConst.MEAN_REL_ERR_STATUS: mean_rel_err_status,
222
+ CompareConst.EB_STATUS: eb_status
223
+ }
224
+ metrics.update(status_dict)
225
+ metrics.update({CompareConst.COMPARE_RESULT: compare_result})
226
+ return metrics
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor
19
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
20
+
21
+
22
+ class BinaryCompare(BaseCompare):
23
+ """
24
+ Binary comparison class for comparing boolean tensors.
25
+
26
+ This class is designed to compare the output of a binary operation between a benchmark and a device.
27
+ It calculates the error rate of the comparison and provides a simple metric for assessing the accuracy.
28
+
29
+ Attributes:
30
+ bench_output (np.ndarray): The output from the benchmark.
31
+ device_output (np.ndarray): The output from the device.
32
+ compare_column (object): The column object to store comparison results.
33
+ dtype (torch.dtype): The data type of the outputs.
34
+
35
+ Methods:
36
+ _compute_metrics(): Computes the comparison metrics, specifically the error rate.
37
+
38
+ Note:
39
+ This class assumes that the input data is an instance of InputData containing the benchmark output,
40
+ device output, comparison column, and data type. The outputs are expected to be boolean tensors.
41
+
42
+ See Also:
43
+ BaseCompare: The base class for comparison classes.
44
+ compare_bool_tensor: The function used to compare boolean tensors.
45
+ """
46
+ def __init__(self, input_data):
47
+ super(BinaryCompare, self).__init__(input_data)
48
+
49
+ def _pre_compare(self):
50
+ pass
51
+
52
+ def _compute_metrics(self):
53
+ """
54
+ Computes the error rate metric for the comparison between benchmark and device outputs.
55
+
56
+ This method calculates the proportion of mismatches between the benchmark output and the device output.
57
+ It uses the `compare_bool_tensor` function to compare the two tensors and extract the error rate.
58
+
59
+ Returns:
60
+ dict: A dictionary containing the computed error rate metric.
61
+ The dictionary has the following key:
62
+ - "error_rate": The proportion of mismatches between the benchmark and device outputs.
63
+ """
64
+ error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output)
65
+
66
+ return {
67
+ "error_rate": error_rate
68
+ }