mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py
@@ -0,0 +1,151 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ import numpy as np
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import convert_str_to_float
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \
+     get_finite_and_infinite_mask, get_small_value_mask
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
+
+
+ class BaseCompare(ABC):
+     """
+     Base comparison class for benchmarking and device output.
+
+     This class provides a foundation for comparing benchmark outputs with device outputs.
+     It encapsulates the common logic for calculating accuracy metrics and
+     provides a framework for subclasses to implement specific comparison logic.
+
+     Attributes:
+         bench_output (np.ndarray): The output from the benchmark.
+         device_output (np.ndarray): The output from the device.
+         compare_column (object): The column object to store comparison results.
+         dtype (torch.dtype): The data type of the outputs.
+
+     Methods:
+         get_small_value_threshold(): Retrieves the small value threshold for the given data type.
+         stat_abs_bench_with_eps(): Calculates the absolute benchmark output with epsilon.
+         stat_abs_error(): Calculates the absolute error between the benchmark and device outputs.
+         stat_finite_and_infinite_mask(): Generates masks for finite and infinite/NaN values.
+         stat_small_value_mask(abs_bench, both_finite_mask, small_value): Creates a mask for small values.
+         compare(): Performs the comparison and computes metrics.
+         _pre_compare(): Pre-comparison hook for subclass-specific initialization.
+         _compute_metrics(): Computes the comparison metrics.
+         _post_compare(metrics): Post-comparison hook to update comparison results.
+
+     Note:
+         This class assumes that the input data is an instance of InputData containing the benchmark output,
+         device output, comparison column, and data type. Subclasses should implement the _pre_compare,
+         _compute_metrics, and _post_compare methods to provide specific comparison logic.
+
+     See Also:
+         InputData: The class containing input data for comparison.
+         StandardConfig: The class containing standard configuration values.
+     """
+     def __init__(self, input_data):
+         self.bench_output = input_data.bench_output
+         self.device_output = input_data.device_output
+         self.compare_column = input_data.compare_column
+         self.dtype = input_data.dtype
+         self.compare_algorithm = None
+
+     @staticmethod
+     def stat_small_value_mask(abs_bench, both_finite_mask, small_value):
+         small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value)
+         return small_value_mask
+
+     @staticmethod
+     def _get_rel_err(abs_err, abs_bench_with_eps):
+         rel_err = abs_err / abs_bench_with_eps
+         return rel_err
+
+     @staticmethod
+     def _get_normal_value_mask(both_finite_mask, small_value_mask):
+         return np.logical_and(both_finite_mask, np.logical_not(small_value_mask))
+
+     @abstractmethod
+     def _pre_compare(self):
+         raise NotImplementedError
+
+     def get_small_value_threshold(self):
+         small_value = StandardConfig.get_small_value(self.dtype, self.compare_algorithm)
+         small_value_atol = StandardConfig.get_small_value_atol(self.dtype, self.compare_algorithm)
+         return small_value, small_value_atol
+
+     def stat_abs_bench_with_eps(self):
+         abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(self.bench_output, self.dtype)
+         return abs_bench, abs_bench_with_eps
+
+     def stat_abs_error(self):
+         abs_err = get_abs_err(self.bench_output, self.device_output)
+         return abs_err
+
+     def stat_finite_and_infinite_mask(self):
+         both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(self.bench_output, self.device_output)
+         return both_finite_mask, inf_nan_mask
+
+     def compare(self):
+         self._pre_compare()
+         metrics = self._compute_metrics()
+         self._post_compare(metrics)
+
+     def _compute_metrics(self):
+         return {}
+
+     def _post_compare(self, metrics):
+         self.compare_column.update(metrics)
+
+
+ class BasePrecisionCompare:
+     def __init__(self, input_data):
+         self.row_npu = input_data.row_npu
+         self.row_gpu = input_data.row_gpu
+         self.dtype = input_data.dtype
+         self.compare_column = input_data.compare_column
+         self.compare_algorithm = None
+
+     @abstractmethod
+     def _get_status(self, metrics, inf_nan_consistency):
+         pass
+
+     @abstractmethod
+     def _compute_ratio(self):
+         pass
+
+     def compare(self):
+         metrics, inf_nan_consistency = self._compute_ratio()
+         compare_result = self._post_compare(metrics, inf_nan_consistency)
+         return compare_result
+
+     def _get_and_convert_values(self, column_name):
+         npu_value = self.row_npu.get(column_name)
+         gpu_value = self.row_gpu.get(column_name)
+         if npu_value is None:
+             raise ValueError(f"NPU value for column '{column_name}' is None.")
+         if gpu_value is None:
+             raise ValueError(f"GPU value for column '{column_name}' is None.")
+         npu_value = convert_str_to_float(npu_value)
+         gpu_value = convert_str_to_float(gpu_value)
+         return npu_value, gpu_value
+
+     def _post_compare(self, metrics, inf_nan_consistency):
+         metrics = self._get_status(metrics, inf_nan_consistency)
+         metrics.update({'compare_algorithm': self.compare_algorithm})
+         self.compare_column.update(metrics)
+         compare_result = metrics.get('compare_result')
+         return compare_result
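The template-method flow above (compare() calling _pre_compare, _compute_metrics, and _post_compare) can be exercised with a small subclass. The sketch below is illustrative only and is not part of the diff; it assumes mindstudio-probe 1.2.1 plus numpy and torch are installed, and uses SimpleNamespace and a plain dict as stand-ins for the package's InputData container and compare-column object (the latter only needs an update() method here).

import numpy as np
import torch
from types import SimpleNamespace

from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare


class MaxAbsErrCompare(BaseCompare):
    """Toy subclass: report only the maximum absolute error."""

    def _pre_compare(self):
        # Hook called first by compare(); prepare whatever the metrics need.
        self.abs_err = np.abs(self.bench_output - self.device_output)

    def _compute_metrics(self):
        return {"max_abs_error": float(np.max(self.abs_err))}


results = {}  # a plain dict works as compare_column: _post_compare only calls .update()
data = SimpleNamespace(
    bench_output=np.array([1.0, 2.0, 3.0], dtype=np.float32),
    device_output=np.array([1.0, 2.1, 2.9], dtype=np.float32),
    compare_column=results,
    dtype=torch.float32,
)
MaxAbsErrCompare(data).compare()
print(results)  # {'max_abs_error': ~0.1}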
msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py
@@ -0,0 +1,226 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from collections import namedtuple
+ import numpy as np
+
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_small_value_err_ratio, get_rel_err, \
+     get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
+     is_inf_or_nan
+ from msprobe.core.common.const import CompareConst
+
+
+ BenchmarkInfNanConsistency = namedtuple('BenchmarkInfNanConsistency', ['small_value_inf_nan_consistency',
+                                                                        'rmse_inf_nan_consistency',
+                                                                        'max_rel_inf_nan_consistency',
+                                                                        'mean_rel_inf_nan_consistency',
+                                                                        'eb_inf_nan_consistency'])
+
+
+ class BenchmarkCompare(BaseCompare):
+     """
+     Benchmark comparison class for calculating accuracy metrics.
+
+     This class is designed to compare the output of a benchmark test with the output of a device.
+     It calculates various metrics such as small value error ratio, RMSE, error balance, max relative error,
+     and mean relative error to assess the accuracy of the device output against the benchmark output.
+
+     Attributes:
+         bench_output (np.ndarray): The output from the benchmark.
+         device_output (np.ndarray): The output from the device.
+         dtype (torch.dtype): The data type of the outputs.
+         abs_bench (np.ndarray): The absolute value of the benchmark output.
+         abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon.
+         both_finite_mask (np.ndarray): A mask indicating where both outputs are finite.
+         inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN.
+         abs_err (np.ndarray): The absolute error between the benchmark and device outputs.
+         small_value (float): The small value threshold for comparison.
+         small_value_atol (float): The absolute tolerance for small values.
+         small_value_mask (np.ndarray): A mask indicating where values are small.
+         rel_err (np.ndarray): The relative error between the benchmark and device outputs.
+         abs_err_greater_mask (np.ndarray): A mask indicating where absolute error is greater than the small value
+             tolerance.
+
+     Methods:
+         _get_abs_err_greater_mask(small_value_atol): Calculates a mask where absolute error is greater than the small
+             value tolerance.
+         _compute_rel_err(): Computes the relative error between the benchmark and device outputs.
+         _pre_compare(): Prepares the comparison by calculating various metrics.
+         _compute_metrics(): Computes the accuracy metrics.
+
+     Note:
+         This class assumes that the input data is a dictionary containing 'bench_output', 'device_output',
+         'compare_column' and 'dtype'.
+         The data type should be a PyTorch data type.
+
+     See Also:
+         BaseCompare: The base class for comparison classes.
+         InputData: The class containing input data for comparison.
+     """
+
+     def __init__(self, input_data):
+         super(BenchmarkCompare, self).__init__(input_data)
+         self.compare_algorithm = CompareConst.BENCHMARK
+
+     def _get_abs_err_greater_mask(self, small_value_atol):
+         abs_err_greater_mask = np.greater(self.abs_err, small_value_atol)
+         return abs_err_greater_mask
+
+     def _compute_rel_err(self):
+         rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask)
+         return rel_err
+
+     def _pre_compare(self):
+         self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps()
+         self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask()
+         self.abs_err = self.stat_abs_error()
+         self.small_value, self.small_value_atol = self.get_small_value_threshold()
+         self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value)
+         self.rel_err = self._compute_rel_err()
+         self.abs_err_greater_mask = self._get_abs_err_greater_mask(self.small_value_atol)
+
+     def _compute_metrics(self):
+         """
+         Computes a comprehensive set of error metrics for the comparison between benchmark and device outputs.
+
+         This method calculates five key metrics:
+         1. Small Value Error Ratio: The proportion of errors associated with small values.
+         2. Root Mean Square Error (RMSE): The square root of the mean of the squared errors.
+         3. Error Balance (EB): A measure of the balance between the errors in the benchmark and device outputs.
+         4. Maximum Relative Error: The maximum relative error between the benchmark and device outputs.
+         5. Mean Relative Error: The mean relative error between the benchmark and device outputs.
+
+         Returns:
+             dict: A dictionary containing the computed error metrics.
+             The dictionary has the following keys:
+             - "small_value_err_ratio": The proportion of errors associated with small values.
+             - "max_rel_error": The maximum relative error.
+             - "mean_rel_error": The mean relative error.
+             - "rmse": The root mean square error.
+             - "eb": The error balance.
+         """
+         small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, self.abs_err_greater_mask)
+         rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask))
+         eb = get_error_balance(self.bench_output, self.device_output)
+         max_rel_error = get_max_rel_err(self.rel_err)
+         mean_rel_error = get_mean_rel_err(self.rel_err)
+
+         return {
+             "small_value_err_ratio": small_value_err_ratio,
+             "max_rel_error": max_rel_error,
+             "mean_rel_error": mean_rel_error,
+             "rmse": rmse,
+             "eb": eb
+         }
+
+
+ class BenchmarkPrecisionCompare(BasePrecisionCompare):
+     def __init__(self, input_data):
+         super().__init__(input_data)
+         self.compare_algorithm = CompareConst.BENCHMARK_COMPARE_ALGORITHM_NAME
+
+     @staticmethod
+     def get_final_status(status_list):
+         compare_result = CompareConst.PASS
+         if CompareConst.ERROR in status_list:
+             compare_result = CompareConst.ERROR
+         elif CompareConst.WARNING in status_list:
+             compare_result = CompareConst.WARNING
+         return compare_result
+
+     def _calc_ratio(self, column_name):
+         npu_value, gpu_value = self._get_and_convert_values(column_name)
+         if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
+             return check_inf_or_nan(npu_value, gpu_value, column_name)
+         else:
+             return calc_ratio(npu_value, gpu_value, str(self.dtype)), True, ""
+
+     def _compute_ratio(self):
+         compare_message = ""
+         small_value_err_ratio, small_value_inf_nan_consistency, small_value_message = \
+             self._calc_ratio(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE)
+         compare_message += small_value_message
+         rmse_ratio, rmse_inf_nan_consistency, rmse_message = self._calc_ratio(ApiPrecisionCompareColumn.RMSE)
+         compare_message += rmse_message
+         max_rel_err_ratio, max_rel_inf_nan_consistency, max_rel_message = \
+             self._calc_ratio(ApiPrecisionCompareColumn.MAX_REL_ERR)
+         compare_message += max_rel_message
+         mean_rel_err_ratio, mean_rel_inf_nan_consistency, mean_rel_message = \
+             self._calc_ratio(ApiPrecisionCompareColumn.MEAN_REL_ERR)
+         compare_message += mean_rel_message
+         eb_ratio, eb_inf_nan_consistency, eb_message = self._calc_ratio(ApiPrecisionCompareColumn.EB)
+         compare_message += eb_message
+
+         metrics = {
+             CompareConst.SMALL_VALUE_ERR_RATIO: small_value_err_ratio,
+             CompareConst.RMSE_RATIO: rmse_ratio,
+             CompareConst.MAX_REL_ERR_RATIO: max_rel_err_ratio,
+             CompareConst.MEAN_REL_ERR_RATIO: mean_rel_err_ratio,
+             CompareConst.EB_RATIO: eb_ratio,
+             CompareConst.COMPARE_MESSAGE: compare_message
+         }
+
+         return metrics, \
+             BenchmarkInfNanConsistency(small_value_inf_nan_consistency, rmse_inf_nan_consistency,
+                                        max_rel_inf_nan_consistency, mean_rel_inf_nan_consistency,
+                                        eb_inf_nan_consistency)
+
+     def _get_threshold(self, metric):
+         error_threshold = StandardConfig.get_benchmark_threshold(metric)
+         return error_threshold
+
+     def _get_single_metric_status(self, ratio, metric):
+         if is_inf_or_nan(ratio):
+             return CompareConst.PASS
+         error_threshold = self._get_threshold(metric)
+         if ratio > error_threshold:
+             return CompareConst.ERROR
+         return CompareConst.PASS
+
+     def _get_status(self, metrics, inf_nan_consistency):
+         small_value_err_ratio = metrics.get(CompareConst.SMALL_VALUE_ERR_RATIO)
+         rmse_ratio = metrics.get(CompareConst.RMSE_RATIO)
+         max_rel_err_ratio = metrics.get(CompareConst.MAX_REL_ERR_RATIO)
+         mean_rel_err_ratio = metrics.get(CompareConst.MEAN_REL_ERR_RATIO)
+         eb_ratio = metrics.get(CompareConst.EB_RATIO)
+
+         small_value_err_status = self._get_single_metric_status(small_value_err_ratio, CompareConst.SMALL_VALUE) \
+             if inf_nan_consistency.small_value_inf_nan_consistency else CompareConst.ERROR
+         rmse_status = self._get_single_metric_status(rmse_ratio, CompareConst.RMSE) \
+             if inf_nan_consistency.rmse_inf_nan_consistency else CompareConst.ERROR
+         max_rel_err_status = self._get_single_metric_status(max_rel_err_ratio, CompareConst.MAX_REL_ERR) \
+             if inf_nan_consistency.max_rel_inf_nan_consistency else CompareConst.ERROR
+         mean_rel_err_status = self._get_single_metric_status(mean_rel_err_ratio, CompareConst.MEAN_REL_ERR) \
+             if inf_nan_consistency.mean_rel_inf_nan_consistency else CompareConst.ERROR
+         eb_status = self._get_single_metric_status(eb_ratio, CompareConst.EB) \
+             if inf_nan_consistency.eb_inf_nan_consistency else CompareConst.ERROR
+         status_list = [small_value_err_status, rmse_status, max_rel_err_status, mean_rel_err_status]
+         compare_result = self.get_final_status(status_list)
+         status_dict = {
+             CompareConst.SMALL_VALUE_ERR_STATUS: small_value_err_status,
+             CompareConst.RMSE_STATUS: rmse_status,
+             CompareConst.MAX_REL_ERR_STATUS: max_rel_err_status,
+             CompareConst.MEAN_REL_ERR_STATUS: mean_rel_err_status,
+             CompareConst.EB_STATUS: eb_status
+         }
+         metrics.update(status_dict)
+         metrics.update({CompareConst.COMPARE_RESULT: compare_result})
+         return metrics
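For reference, the pass/error decision implemented by _get_single_metric_status and get_final_status above can be summarised in a standalone sketch. This is an illustration, not the package's code: plain strings and metric names stand in for the CompareConst members, and the error thresholds are the error_threshold values defined in standard_config.py later in this diff.

import math

# error_threshold values from StandardConfig (see standard_config.py below)
ERROR_THRESHOLDS = {"small_value": 2, "rmse": 2, "max_rel_err": 10, "mean_rel_err": 2, "eb": 2}


def single_metric_status(ratio, metric):
    # Mirrors _get_single_metric_status: inf/nan ratios pass, otherwise compare to the threshold.
    if math.isinf(ratio) or math.isnan(ratio):
        return "pass"
    return "error" if ratio > ERROR_THRESHOLDS[metric] else "pass"


def final_status(status_list):
    # Mirrors get_final_status: any error wins, then warning, else pass.
    if "error" in status_list:
        return "error"
    if "warning" in status_list:
        return "warning"
    return "pass"


# Only these four statuses feed get_final_status above; eb_status is reported separately.
ratios = {"small_value": 1.3, "rmse": 0.9, "max_rel_err": 12.0, "mean_rel_err": 1.1}
statuses = [single_metric_status(v, k) for k, v in ratios.items()]
print(statuses, "->", final_status(statuses))  # max_rel_err ratio 12.0 > 10 -> error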
msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py
@@ -0,0 +1,68 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
+
+
+ class BinaryCompare(BaseCompare):
+     """
+     Binary comparison class for comparing boolean tensors.
+
+     This class is designed to compare the output of a binary operation between a benchmark and a device.
+     It calculates the error rate of the comparison and provides a simple metric for assessing the accuracy.
+
+     Attributes:
+         bench_output (np.ndarray): The output from the benchmark.
+         device_output (np.ndarray): The output from the device.
+         compare_column (object): The column object to store comparison results.
+         dtype (torch.dtype): The data type of the outputs.
+
+     Methods:
+         _compute_metrics(): Computes the comparison metrics, specifically the error rate.
+
+     Note:
+         This class assumes that the input data is an instance of InputData containing the benchmark output,
+         device output, comparison column, and data type. The outputs are expected to be boolean tensors.
+
+     See Also:
+         BaseCompare: The base class for comparison classes.
+         compare_bool_tensor: The function used to compare boolean tensors.
+     """
+     def __init__(self, input_data):
+         super(BinaryCompare, self).__init__(input_data)
+
+     def _pre_compare(self):
+         pass
+
+     def _compute_metrics(self):
+         """
+         Computes the error rate metric for the comparison between benchmark and device outputs.
+
+         This method calculates the proportion of mismatches between the benchmark output and the device output.
+         It uses the `compare_bool_tensor` function to compare the two tensors and extract the error rate.
+
+         Returns:
+             dict: A dictionary containing the computed error rate metric.
+             The dictionary has the following key:
+             - "error_rate": The proportion of mismatches between the benchmark and device outputs.
+         """
+         error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output)
+
+         return {
+             "error_rate": error_rate
+         }
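BinaryCompare reduces the comparison to a single error rate. A standalone illustration of that metric, assuming (as the docstring states) that the first value returned by compare_bool_tensor is the proportion of mismatched elements:

import numpy as np

bench = np.array([True, False, True, True])
device = np.array([True, True, True, False])
error_rate = float(np.mean(bench != device))  # 2 mismatches out of 4 elements
print(error_rate)  # 0.5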
msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py
@@ -0,0 +1,218 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from msprobe.core.common.const import CompareConst
+
+
+ class StandardConfig:
+     """
+     Standard configuration class for managing precision and comparison thresholds.
+
+     This class provides a centralized way to manage the small value thresholds, absolute tolerances,
+     and relative tolerances (rtol) used in precision comparisons. It allows for different thresholds
+     based on the data type, with default values provided for common data types.
+
+     Attributes:
+         _small_value (dict): A dictionary mapping data types to their corresponding small value thresholds.
+         _small_value_atol (dict): A dictionary mapping data types to their corresponding absolute tolerances.
+         _rtol (dict): A dictionary mapping data types to their corresponding relative tolerances.
+
+     Methods:
+         get_small_value(dtype): Retrieves the small value threshold for the given data type.
+         get_small_value_atol(dtype): Retrieves the absolute tolerance for the given data type.
+         get_rtol(dtype): Retrieves the relative tolerance for the given data type.
+
+     Example:
+         >>> small_value = StandardConfig.get_small_value(torch.float32)
+         >>> atol = StandardConfig.get_small_value_atol(torch.float32)
+         >>> rtol = StandardConfig.get_rtol(torch.float32)
+         >>> print(small_value, atol, rtol)
+         1e-6 1e-9 1e-6
+
+     Note:
+         The data type is expected to be a PyTorch data type. If the data type is not found in the dictionary,
+         the default value is returned.
+
+     See Also:
+         torch.dtype: PyTorch data types.
+     """
+     _small_value = {
+         torch.float16: 2**-10,
+         torch.bfloat16: 2**-10,
+         torch.float32: 2**-20,
+         "default": 2**-20
+     }
+     _threshold_small_value_atol = {
+         torch.float16: 2**-16,
+         torch.bfloat16: 1e-16,
+         torch.float32: 2**-30,
+         "default": 2**-30
+     }
+     _benchmark_small_value_atol = {
+         torch.float16: 1e-16,
+         torch.bfloat16: 1e-16,
+         torch.float32: 2**-30,
+         "default": 2**-30
+     }
+     _rtol = {
+         torch.float16: 2**-10,
+         torch.bfloat16: 2**-8,
+         torch.float32: 2**-20,
+         "default": 2**-20
+     }
+     _accumulative_error_bound = {
+         torch.float16: 2**-8,
+         torch.bfloat16: 2**-7,
+         torch.float32: 2**-11,
+         "default": 2**-11
+     }
+     _small_value_threshold = {
+         'error_threshold': 2,
+         'warning_threshold': 1,
+         "default": 1
+     }
+     _rmse_threshold = {
+         'error_threshold': 2,
+         'warning_threshold': 1,
+         "default": 1
+     }
+     _max_rel_err_threshold = {
+         'error_threshold': 10,
+         'warning_threshold': 1,
+         "default": 1
+     }
+     _mean_rel_err_threshold = {
+         'error_threshold': 2,
+         'warning_threshold': 1,
+         "default": 1
+     }
+     _eb_threshold = {
+         'error_threshold': 2,
+         'warning_threshold': 1,
+         "default": 1
+     }
+     _minmum_err = {
+         'torch.float16': 2**-11,
+         'torch.bfloat16': 2**-8,
+         'torch.float32': 2**-14,
+         'default': 2**-14
+     }
+     _accumulative_error_eb_threshold = {
+         'torch.float16': 2**-20,
+         'torch.bfloat16': 2**-7,
+         'torch.float32': 2**-14,
+         'default': 2**-14
+     }
+
+     _fp32_mean_ulp_err_threshold = 64
+     ulp_err_proportion_ratio = 1
+     _fp32_ulp_err_proportion = 0.05
+     _fp16_ulp_err_proportion = 0.001
+     _special_samll_value = 1
+
+     @classmethod
+     def get_small_value(cls, dtype, standard):
+         if standard == CompareConst.ACCUMULATIVE_ERROR_COMPARE:
+             return cls._special_samll_value
+         return cls._small_value.get(dtype, cls._small_value["default"])
+
+     @classmethod
+     def get_small_value_atol(cls, dtype, standard):
+         standard_dict = {
+             CompareConst.ABSOLUTE_THRESHOLD: cls._threshold_small_value_atol,
+             CompareConst.BENCHMARK: cls._benchmark_small_value_atol
+         }
+         small_value_atol_standard = standard_dict.get(standard, cls._benchmark_small_value_atol)
+         return small_value_atol_standard.get(dtype, small_value_atol_standard["default"])
+
+     @classmethod
+     def get_rtol(cls, dtype):
+         return cls._rtol.get(dtype, cls._rtol["default"])
+
+     @classmethod
+     def get_small_value_threshold(cls, threshold_type):
+         return cls._small_value_threshold.get(threshold_type, cls._small_value_threshold["default"])
+
+     @classmethod
+     def get_rmse_threshold(cls, threshold_type):
+         return cls._rmse_threshold.get(threshold_type, cls._rmse_threshold["default"])
+
+     @classmethod
+     def get_max_rel_err_threshold(cls, threshold_type):
+         return cls._max_rel_err_threshold.get(threshold_type, cls._max_rel_err_threshold["default"])
+
+     @classmethod
+     def get_mean_rel_err_threshold(cls, threshold_type):
+         return cls._mean_rel_err_threshold.get(threshold_type, cls._mean_rel_err_threshold["default"])
+
+     @classmethod
+     def get_eb_threshold(cls, threshold_type):
+         return cls._eb_threshold.get(threshold_type, cls._eb_threshold["default"])
+
+     @classmethod
+     def get_benchmark_threshold(cls, metric):
+         metric_threshold_functions = {
+             'small_value': StandardConfig.get_small_value_threshold,
+             'rmse': StandardConfig.get_rmse_threshold,
+             'max_rel_err': StandardConfig.get_max_rel_err_threshold,
+             'mean_rel_err': StandardConfig.get_mean_rel_err_threshold,
+             'eb': StandardConfig.get_eb_threshold
+         }
+
+         threshold_func = metric_threshold_functions.get(metric)
+         return threshold_func('error_threshold')
+
+     @classmethod
+     def get_fp32_mean_ulp_err_threshold(cls):
+         return cls._fp32_mean_ulp_err_threshold
+
+     @classmethod
+     def get_ulp_err_proportion_ratio_threshold(cls):
+         return cls.ulp_err_proportion_ratio
+
+     @classmethod
+     def get_fp32_ulp_err_proportion_threshold(cls):
+         return cls._fp32_ulp_err_proportion
+
+     @classmethod
+     def get_fp16_ulp_err_proportion_threshold(cls):
+         return cls._fp16_ulp_err_proportion
+
+     @classmethod
+     def get_ulp_threshold(cls, dtype):
+         ulp_err_proportion_ratio_threshold = StandardConfig.get_ulp_err_proportion_ratio_threshold()
+         if dtype == torch.float32:
+             mean_ulp_err_threshold = StandardConfig.get_fp32_mean_ulp_err_threshold()
+             ulp_err_proportion_threshold = StandardConfig.get_fp32_ulp_err_proportion_threshold()
+             return mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold
+         else:
+             ulp_err_proportion_threshold = StandardConfig.get_fp16_ulp_err_proportion_threshold()
+             return None, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold
+
+     @classmethod
+     def get_minmum_err(cls, dtype):
+         return cls._minmum_err.get(dtype, cls._minmum_err["default"])
+
+     @classmethod
+     def get_accumulative_error_bound(cls, dtype):
+         return cls._accumulative_error_bound.get(dtype, cls._accumulative_error_bound["default"])
+
+     @classmethod
+     def get_accumulative_error_eb_threshold(cls, dtype):
+         return cls._accumulative_error_eb_threshold.get(dtype, cls._accumulative_error_eb_threshold["default"])
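A hypothetical look-up session against the new StandardConfig, assuming mindstudio-probe 1.2.1 and torch are installed. Note that get_small_value and get_small_value_atol take the comparison standard as a second argument (unlike the single-argument calls shown in the class docstring's Example block); CompareConst.BENCHMARK is the standard name used throughout this diff, and the expected values in the comments come from the tables defined above.

import torch

from msprobe.core.common.const import CompareConst
from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig

print(StandardConfig.get_small_value(torch.float16, CompareConst.BENCHMARK))       # 2**-10
print(StandardConfig.get_small_value_atol(torch.float16, CompareConst.BENCHMARK))  # 1e-16
print(StandardConfig.get_rtol(torch.bfloat16))          # 2**-8
print(StandardConfig.get_benchmark_threshold('rmse'))   # 2 (the error_threshold)
print(StandardConfig.get_ulp_threshold(torch.float32))  # (64, 0.05, 1)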