mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,365 @@
1
+ import json
2
+ import os
3
+ import math
4
+ from enum import Enum, auto
5
+ import torch
6
+ try:
7
+ import torch_npu
8
+ except ImportError:
9
+ pass
10
+ from tabulate import tabulate
11
+
12
# Values of the info "type" field that generate_data treats as tensor inputs.
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
# Dtype names grouped by numeric family. They are compared against str(dtype)
# and against the "dtype" field of an input description.
TORCH_BOOL_TYPE = ["torch.bool"]
TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
                  "torch.int64", "torch.long"]
TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
                    "torch.float64", "torch.double"]
TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
# Maps a float dtype name to the next wider torch dtype.
# NOTE(review): the doubled curly braces are presumably literal braces escaped
# for the template renderer (str.format style) - confirm against the generator.
RAISE_PRECISION = {{
    "torch.float16": torch.float32,
    "torch.half": torch.float32,
    "torch.bfloat16": torch.float32,
    "torch.float32": torch.float64,
    "torch.float": torch.float64
}}
# Used by the thousandth standard both as the per-element relative error bound
# and, as (1 - value), as the required pass ratio.
THOUSANDTH_THRESHOLDING = 0.001
# Value of "propagation" that selects the backward (gradient) code path.
BACKWARD = 'backward'
28
+
29
class CompareStandard(Enum):
    """Comparison standards that compare_tensor can apply to a pair of outputs.

    Member values are consecutive integers starting at 1, in declaration order.
    """

    BINARY_EQUALITY_STANDARD = 1
    ABSOLUTE_THRESHOLD_STANDARD = 2
    ULP_ERROR_STANDARD = 3
    BENCHMARK_STANDARD = 4
    THOUSANDTH_STANDARD = 5
35
+
36
def load_pt(pt_path, to_cpu=False):
    """Load an object saved with torch.save from ``pt_path``.

    Args:
        pt_path: Path of the file to load; it is resolved with os.path.realpath first.
        to_cpu: When true, storages are remapped to the CPU device while loading.

    Returns:
        Whatever object torch.load produces for the file.

    Raises:
        RuntimeError: If torch.load fails for any reason (chained to the cause).
    """
    pt_path = os.path.realpath(pt_path)
    # map_location=None is torch.load's default, so this equals a plain load.
    location = torch.device("cpu") if to_cpu else None
    try:
        return torch.load(pt_path, map_location=location)
    except Exception as e:
        raise RuntimeError(f"load pt file {{pt_path}} failed") from e
46
+
47
def get_device():
    """Pick the accelerator device the API under test should run on.

    CUDA is preferred; otherwise an NPU is used when torch_npu was imported
    successfully and reports an available device.

    Returns:
        torch.device: ``cuda`` or ``npu``.

    Raises:
        Exception: If neither a CUDA device nor an NPU is available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch_npu is imported at module level inside a try/except ImportError, so the
    # name may not exist. Look it up defensively so a machine without torch_npu
    # reports the intended error below instead of failing with NameError.
    npu_module = globals().get("torch_npu")
    if npu_module is not None and npu_module.npu.is_available():
        return torch.device("npu")
    raise Exception("Error: This device is not NPU or GPU!")
55
+
56
+
57
def generate_bool_tensor(low, high, shape):
    """Create a random boolean tensor of the given shape.

    Integers are drawn from the closed range [int(low), int(high)] and every
    strictly positive draw becomes True.

    Args:
        low: Lower bound; truncated with int().
        high: Upper bound (inclusive); truncated with int().
        shape: Shape of the tensor to create.

    Returns:
        torch.Tensor of dtype torch.bool.
    """
    lower = int(low)
    upper = int(high)
    # torch.randint excludes its upper bound, hence the + 1.
    samples = torch.randint(lower, upper + 1, shape)
    return samples > 0
62
+
63
+
64
def generate_numerical_tensor(low, high, shape, data_dtype):
    """Create a random float or integer tensor whose values lie in [low, high].

    For a non-empty result the first and last elements (in flattened order) are
    overwritten with ``low`` and ``high`` so both extremes are always present.

    Args:
        low: Lower bound of the value range.
        high: Upper bound of the value range (inclusive for integer dtypes).
        shape: Shape of the tensor to create.
        data_dtype: Dtype name such as "torch.float32"; it must be listed in
            TORCH_FLOAT_TYPE or TORCH_INT_TYPE.

    Returns:
        torch.Tensor of the requested dtype and shape.

    Raises:
        NotImplementedError: If ``data_dtype`` is neither a float nor an int dtype name.
    """
    if data_dtype in TORCH_FLOAT_TYPE:
        # The name was just checked against a fixed list of "torch.<name>" strings,
        # so a plain attribute lookup yields the dtype without evaluating a string.
        torch_dtype = getattr(torch, data_dtype.split(".")[-1])
        scale = high - low
        rand01 = torch.rand(shape, dtype=torch_dtype)
        tensor = rand01 * scale + low
    elif data_dtype in TORCH_INT_TYPE:
        torch_dtype = getattr(torch, data_dtype.split(".")[-1])
        low, high = int(low), int(high)
        # torch.randint excludes its upper bound, hence the + 1.
        tensor = torch.randint(low, high + 1, shape, dtype=torch_dtype)
    else:
        raise NotImplementedError(f"{{data_dtype}} is not supported!")
    if torch.numel(tensor) == 0:
        return tensor
    # Pin both extremes of the range into the data.
    tmp_tensor = tensor.reshape(-1)
    tmp_tensor[0] = low
    tmp_tensor[-1] = high
    data = tmp_tensor.reshape(shape)
    return data
81
+
82
+
83
def generate_random_tensor(info):
    """Create a random tensor from an input description.

    Args:
        info: Mapping read with .get() for the keys 'Min', 'Max', 'dtype' and 'shape'.

    Returns:
        A boolean tensor when the dtype name is "torch.bool", otherwise the
        result of generate_numerical_tensor.
    """
    low = info.get('Min')
    high = info.get('Max')
    data_dtype = info.get('dtype')
    shape = tuple(info.get('shape'))
    if data_dtype == "torch.bool":
        return generate_bool_tensor(low, high, shape)
    return generate_numerical_tensor(low, high, shape, data_dtype)
92
+
93
+
94
def generate_real_tensor(data_path):
    """Load previously dumped data from ``data_path`` onto the CPU.

    Args:
        data_path: Path of the saved file; resolved with os.path.realpath.

    Returns:
        The object returned by load_pt for that file.
    """
    resolved_path = os.path.realpath(data_path)
    return load_pt(resolved_path, to_cpu=True)
98
+
99
+
100
def generate_data(info):
    """Build one API input from its description.

    Tensor-typed entries are loaded from "data_name" when that field is set and
    are randomly generated otherwise; any other entry yields its "value" field.

    Args:
        info: Mapping read with .get() for "type", "data_name", "requires_grad"
            and, for non-tensor entries, "value".

    Returns:
        The tensor or plain value, with requires_grad enabled when requested.
    """
    data_type = info.get("type")
    data_path = info.get("data_name")
    data_grad = info.get("requires_grad")
    if data_type not in TENSOR_DATA_LIST:
        data = info.get("value")
    elif data_path:
        data = generate_real_tensor(data_path)
    else:
        data = generate_random_tensor(info)
    if data_grad == True:
        data.requires_grad_(True)
    return data
114
+
115
+
116
def get_input(propagation):
    """Build the positional and keyword inputs for the device run and the benchmark run.

    Template function: each curly-brace name in the body is a placeholder that is
    presumably replaced with generated code when the script is rendered - confirm
    against the generator.

    Args:
        propagation: Compared with BACKWARD to decide whether the backward
            inputs are returned as well.

    Returns:
        (args_device, kwargs_device, args_bench, kwargs_bench), extended with
        (args_device_backward, args_bench_backward) when propagation is BACKWARD.
    """
    # Placeholder: statements that create the individual positional inputs.
    {args_element_assignment}
    args_device = [{args_list_generator_device}]
    args_bench = [{args_list_generator_bench}]
    # Placeholder: statements that create the individual keyword inputs.
    {kwargs_value_assignment}
    kwargs_device = {{{kwargs_dict_generator_device}}}
    kwargs_bench = {{{kwargs_dict_generator_bench}}}
    # Placeholder: statements that create the inputs of the backward pass.
    {args_element_assignment_backward}
    args_device_backward = [{args_list_generator_device_backward}]
    args_bench_backward = [{args_list_generator_bench_backward}]
    if propagation == BACKWARD:
        return args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward
    return args_device, kwargs_device, args_bench, kwargs_bench
129
+
130
def exec_api(args, kwargs, args_grad_input, propagation):
    """Run the API under test and, for backward propagation, its gradients.

    Args:
        args: Positional arguments passed to the API.
        kwargs: Keyword arguments passed to the API.
        args_grad_input: Passed as grad_outputs to torch.autograd.grad; only
            used when propagation is BACKWARD.
        propagation: Compared with BACKWARD to select the gradient path.

    Returns:
        The API output, or the result of torch.autograd.grad when propagation
        is BACKWARD.
    """
    # The API type and name are template placeholders filled in at render time.
    output = {api_type}.{api_name}(*args, **kwargs)
    if propagation == BACKWARD:
        # Differentiate with respect to every tensor input that requires grad,
        # positional inputs first and keyword inputs after them.
        args_input_tensor = [tensor for tensor in args if isinstance(tensor, torch.Tensor) and tensor.requires_grad]
        args_input_tensor.extend(
            [value for value in kwargs.values() if isinstance(value, torch.Tensor) and value.requires_grad])
        output_backward = torch.autograd.grad(outputs=output, inputs=args_input_tensor, grad_outputs=args_grad_input)
        return output_backward
    return output
139
+
140
def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
    """Return the share of inf/nan positions where device and benchmark disagree.

    Both outputs are clamped to the finite range of the device dtype, so an
    infinity is compared as the largest or smallest finite value. A position
    passes when its clamped relative error is at most ``rtol`` or when both
    values are nan.

    Args:
        inf_nan_mask: Boolean tensor selecting the positions to evaluate.
        out_device: Device output; its dtype must be a floating point dtype.
        out_bench: Benchmark output; converted to the device dtype.
        abs_bench_with_eps: Denominator used for the relative error.
        rtol: Relative tolerance.

    Returns:
        0 when the mask selects nothing, otherwise a tensor holding
        failed positions / selected positions.
    """
    device_dtype = out_device.dtype
    out_bench = out_bench.to(device_dtype)
    limits = torch.finfo(device_dtype)
    lowest, highest = limits.min, limits.max
    bench_clip = torch.clamp(out_bench, min=lowest, max=highest)
    device_clip = torch.clamp(out_device, min=lowest, max=highest)
    clipped_re = torch.abs(device_clip - bench_clip) / abs_bench_with_eps
    within_tolerance = torch.less_equal(clipped_re, rtol)
    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
    failed_mask = torch.logical_not(torch.logical_or(within_tolerance, both_nan_mask))
    # Only positions flagged by the caller count as inf/nan errors.
    failed_mask = torch.logical_and(failed_mask, inf_nan_mask)
    inf_nan_err_cnt = torch.sum(failed_mask)
    selected_cnt = torch.sum(inf_nan_mask)
    if selected_cnt == 0:
        return 0
    return inf_nan_err_cnt / selected_cnt
155
+
156
+
157
def compute_rmse(abs_err, normal_value_mask):
    """Root mean squared error over the positions selected by the mask.

    Args:
        abs_err: Tensor of absolute errors.
        normal_value_mask: Boolean tensor selecting the positions that count.

    Returns:
        0 when the mask selects nothing, otherwise a scalar tensor.
    """
    selected_cnt = torch.sum(normal_value_mask)
    if selected_cnt == 0:
        return 0
    # Unselected positions contribute zero to the sum of squares.
    squared_err = torch.square(torch.where(normal_value_mask, abs_err, 0))
    return torch.sqrt(torch.sum(squared_err) / selected_cnt)
165
+
166
+
167
def compute_error_balance(out_device, out_bench):
    """Measure how one-sided the device error is relative to the benchmark.

    Args:
        out_device: Device output tensor.
        out_bench: Benchmark output tensor; converted to the device dtype.

    Returns:
        Scalar tensor: |count(device > bench) - count(device < bench)| divided
        by the number of benchmark elements.

    Raises:
        ZeroDivisionError: If the benchmark tensor has no elements.
    """
    signed_diff = out_device - out_bench.to(out_device.dtype)
    larger_count = torch.sum(torch.greater(signed_diff, 0))
    smaller_count = torch.sum(torch.less(signed_diff, 0))
    element_cnt = torch.numel(out_bench)
    if element_cnt == 0:
        raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
    return abs(larger_count - smaller_count) / element_cnt
174
+
175
+
176
def compare_tensor(out_device, out_bench, api_name):
    """Compare a device tensor with a benchmark tensor and print a metric table.

    The metric set depends on the module-level ``compare_standard`` value, which
    is assigned in the script entry point. Bool and int outputs, and the binary
    equality standard, report a plain mismatch rate. Float outputs are scored by
    the absolute threshold, ULP error, thousandth or benchmark standard.

    Args:
        out_device: Output produced on the device; moved to the CPU before comparing.
        out_bench: Benchmark output with the same shape.
        api_name: Name of the API under test (not used inside this function).

    Returns:
        None. Results and error messages are printed.
    """
    if out_device.shape != out_bench.shape:
        print("ERROR: shape of out_device and out_bench is not equal!")
        return None
    if torch.numel(out_bench) == 0:
        print("Both out_device and out_bench have zero elements.")
        return None
    dtype_device = out_device.dtype
    dtype_bench = out_bench.dtype
    headers = ["Metric", "Value"]
    table = [
        ["Shape", out_bench.shape],
        ["Dtype of out_device", out_device.dtype],
        ["Dtype of out_bench", out_bench.dtype]
    ]
    # Only outputs of the same dtype family (float/float, int/int, bool/bool) are comparable.
    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
        out_device = out_device.to(torch.device("cpu"))
        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
            error_number = torch.sum(out_device != out_bench).item()
            if torch.numel(out_bench) == 0:
                raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
            error_rate = error_number / torch.numel(out_bench)
            table.append(["Compare Standard", "Binary Equality Standard"])
            table.append(["Error Rate", error_rate])
        else:
            abs_err = torch.abs(out_device - out_bench)
            abs_bench = torch.abs(out_bench)
            # Machine epsilon of the benchmark dtype: 2 ** -23 for float32 and
            # 2 ** -52 for float64, and also defined for half precision benchmarks,
            # which previously left eps unassigned.
            eps = torch.finfo(dtype_bench).eps
            abs_bench_with_eps = abs_bench + eps
            rel_err = torch.abs(abs_err / abs_bench_with_eps)
            device_finite_mask = torch.isfinite(out_device)
            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
            inf_nan_mask = torch.logical_not(both_finite_mask)
            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
                if dtype_device == torch.float16:
                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
                else:
                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
                # Finite values are split into small values (judged by absolute
                # error) and normal values (judged by relative error).
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                # NOTE(review): this value is computed but never added to the table.
                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
                rel_err_mask = torch.greater(rel_err, rtol)
                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
                if torch.sum(normal_value_mask) == 0:
                    rel_err_proportion = 0
                else:
                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    abs_err_proportion = 0
                else:
                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                table.append(["Compare Standard", "Absolute Threshold Standard"])
                table.append(["Relative Error Ratio", rel_err_proportion])
                table.append(["Absolute Error Ratio", abs_err_proportion])
            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
                if dtype_device == torch.float16:
                    min_eb, exponent_num = -14, 10
                elif dtype_device == torch.bfloat16:
                    min_eb, exponent_num = -126, 7
                else:
                    min_eb, exponent_num = -126, 23
                # eb is floor(log2(abs(bench))), zero where the benchmark is zero,
                # and never below min_eb.
                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
                if dtype_device == torch.float32:
                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
                else:
                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
                ulp_err = torch.abs(ulp_err)
                max_ulp_err = torch.max(ulp_err)
                mean_ulp_err = torch.mean(ulp_err)
                if torch.numel(out_bench) == 0:
                    raise ZeroDivisionError(f"ERROR: please check torch.numel out_bench, its value is {{torch.numel(out_bench)}}")
                if dtype_device == torch.float32:
                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
                else:
                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
                table.append(["Compare Standard", "ULP error Standard"])
                table.append(["Maximum ULP Error", max_ulp_err])
                table.append(["Mean ULP Error", mean_ulp_err])
                table.append(["ULP Error Proportion", ulp_err_proportion])
            elif compare_standard == CompareStandard.THOUSANDTH_STANDARD:
                rel_err_origin = torch.abs(abs_err / abs_bench_with_eps)
                if torch.numel(rel_err_origin) == 0:
                    thousand_res = 1
                else:
                    thousand_res = torch.divide(torch.sum(rel_err < THOUSANDTH_THRESHOLDING), torch.numel(rel_err_origin))
                thousand_status = thousand_res > (1 - THOUSANDTH_THRESHOLDING)
                table.append(["Compare Standard", "Thousandth Standard"])
                table.append(["Thousandth ratio", thousand_res])
            else:
                if dtype_device == torch.float16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                elif dtype_device == torch.bfloat16:
                    small_value, small_value_atol = 1.0e-3, 1.0e-5
                else:
                    small_value, small_value_atol = 1.0e-6, 1.0e-9
                small_value_mask = torch.less_equal(abs_bench, small_value)
                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
                abs_err_mask = torch.greater(abs_err, small_value_atol)
                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
                if torch.sum(small_value_mask) == 0:
                    small_value_err_proportion = 0
                else:
                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
                # Positions outside the normal range are marked with -1 so that they
                # cannot win the maximum and are clamped to 0 for the mean.
                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
                if torch.max(rel_err) >= 0:
                    max_rel_err = torch.max(rel_err)
                else:
                    max_rel_err = 0
                if torch.sum(normal_value_mask) == 0:
                    mean_rel_err = 0
                else:
                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
                rmse = compute_rmse(abs_err, normal_value_mask)
                error_balance = compute_error_balance(out_device, out_bench)
                table.append(["Compare Standard", "Benchmark Standard"])
                table.append(["Small Value Error Proportion", small_value_err_proportion])
                table.append(["Maximum Relative Error", max_rel_err])
                table.append(["Mean Relative Error", mean_rel_err])
                table.append(["Root Mean Squared Error", rmse])
                table.append(["Error Balance", error_balance])
    else:
        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
        return None
    print(tabulate(table, headers, tablefmt='grid'))
    return None
314
+
315
+
316
def compare_element(out_device, out_bench, api_name):
    """Compare one device output with one benchmark output and print the verdict.

    Tensors are delegated to compare_tensor; bool, int, float and str values are
    compared with ``==``; any other type is reported as unsupported.

    Args:
        out_device: Output produced on the device.
        out_bench: Benchmark output; must be exactly the same type.
        api_name: Name of the API under test, forwarded to compare_tensor.

    Returns:
        None. All results are printed.
    """
    if type(out_device) is not type(out_bench):
        print("ERROR: out_device and out_bench is not the same type!")
        return None
    if isinstance(out_bench, torch.Tensor):
        compare_tensor(out_device, out_bench, api_name)
        return None
    if not isinstance(out_bench, (bool, int, float, str)):
        print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
        return None
    if out_device == out_bench:
        verdict = "PASS: out_device and out_bench equals."
    else:
        verdict = "ERROR: out_device and out_bench is not equal!"
    print(verdict)
    return None
330
+
331
+
332
def compare(out_device, out_bench, api_name):
    """Print a comparison report for the device and benchmark outputs.

    A list or tuple is compared element by element after checking that both
    sides have the same length; any other output is compared as one element.

    Args:
        out_device: Output, or list/tuple of outputs, produced on the device.
        out_bench: Benchmark output of exactly the same type.
        api_name: Name of the API under test, forwarded to compare_element.

    Returns:
        None. All results are printed.
    """
    print("Compare result:")
    if type(out_device) is not type(out_bench):
        print("ERROR: out_device and out_bench is not the same type!")
        return None
    if not isinstance(out_bench, (list, tuple)):
        compare_element(out_device, out_bench, api_name)
        return None
    if len(out_device) != len(out_bench):
        print("ERROR: len of out_device and out_bench is different!")
        return None
    for index, (device_item, bench_item) in enumerate(zip(out_device, out_bench)):
        print(f"index {{index}}:")
        compare_element(device_item, bench_item, api_name)
    return None
346
+
347
# Script entry point: runs the API on the device and on the benchmark for the
# configured number of iterations and prints a comparison after each run.
# The curly-brace names are template placeholders, presumably filled in when the
# script is generated - confirm against the generator.
if __name__ == "__main__":
    # NOTE(review): device is not referenced by the visible functions; presumably
    # the generated input code uses it - confirm.
    device = get_device()
    api_name = "{api_name}"
    propagation = "{propagation}"
    # Read as a module-level global by compare_tensor.
    compare_standard = {compare_standard}
    # Seed once so the randomly generated inputs are reproducible.
    torch.manual_seed({random_seed})
    for i in range({iter_times}):
        print(f"iter: {{i}}:")
        if propagation == BACKWARD:
            # Backward: get_input also returns the backward inputs, which exec_api
            # passes to torch.autograd.grad as grad_outputs.
            args_device, kwargs_device, args_bench, kwargs_bench, args_device_backward, args_bench_backward = get_input(propagation)
            output_device = exec_api(args_device, kwargs_device, args_device_backward, propagation)
            output_bench = exec_api(args_bench, kwargs_bench, args_bench_backward, propagation)
            compare(output_device, output_bench, api_name)
        else:
            # Forward only: the API outputs are compared directly.
            args_device, kwargs_device, args_bench, kwargs_bench = get_input(propagation)
            output_device = exec_api(args_device, kwargs_device, None, propagation)
            output_bench = exec_api(args_bench, kwargs_bench, None, propagation)
            compare(output_device, output_bench, api_name)
    print("Compare finished.")
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
21
+ check_small_value
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
24
+ from msprobe.core.common.const import CompareConst
25
+
26
+
27
+
28
class AbsolutethdCompare(BaseCompare):
    """
    Absolute-threshold comparison of device output against benchmark output.

    The comparison is tagged with ``CompareConst.ABSOLUTE_THRESHOLD`` and
    reports three ratios: ``inf_nan_error_ratio``, ``rel_err_ratio`` and
    ``abs_err_ratio``.

    Attributes populated by ``_pre_compare``:
        abs_bench: Absolute value of the benchmark output.
        abs_bench_with_eps: Absolute benchmark output with epsilon applied.
        both_finite_mask: Mask of positions where both outputs are finite.
        inf_nan_mask: Mask of positions where either output is inf or NaN.
        abs_err: Absolute error between benchmark and device outputs.
        rtol: Relative tolerance looked up from ``StandardConfig`` by dtype.
        rel_err: Relative error derived from ``abs_err`` and
            ``abs_bench_with_eps``.
        small_value: Threshold below which a value counts as "small".
        small_value_atol: Absolute tolerance applied to small values.
        small_value_mask: Mask of small-value positions.
        normal_value_mask: Mask of the remaining (normal) positions.

    Note:
        ``bench_output``, ``device_output`` and ``dtype`` are read but never
        assigned in this class, and the ``stat_*`` / ``_get_rel_err`` /
        ``_get_normal_value_mask`` / ``get_small_value_threshold`` helpers are
        not defined here; all are expected to be provided by ``BaseCompare``.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """

    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.ABSOLUTE_THRESHOLD

    def _get_rtol(self):
        """Return the relative tolerance ``StandardConfig`` holds for ``self.dtype``."""
        return StandardConfig.get_rtol(self.dtype)

    def _pre_compare(self):
        """
        Compute and cache every intermediate value ``_compute_metrics`` needs.

        Steps, in order:
            1. Absolute benchmark values and their epsilon-adjusted version.
            2. Finite and inf/NaN masks.
            3. Absolute error.
            4. Relative tolerance for the dtype.
            5. Relative error from the absolute error and the
               epsilon-adjusted benchmark values.
            6. Small-value threshold and its absolute tolerance.
            7. Small-value mask from the benchmark values and finite mask.
            8. Normal-value mask from the finite mask and small-value mask.
        """
        bench_stats = self.stat_abs_bench_with_eps()
        self.abs_bench, self.abs_bench_with_eps = bench_stats

        finite_masks = self.stat_finite_and_infinite_mask()
        self.both_finite_mask, self.inf_nan_mask = finite_masks

        self.abs_err = self.stat_abs_error()
        self.rtol = self._get_rtol()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)

        small_value_cfg = self.get_small_value_threshold()
        self.small_value, self.small_value_atol = small_value_cfg

        self.small_value_mask = self.stat_small_value_mask(
            self.abs_bench, self.both_finite_mask, self.small_value
        )
        self.normal_value_mask = self._get_normal_value_mask(
            self.both_finite_mask, self.small_value_mask
        )

    def _compute_metrics(self):
        """Return a dict with the inf/NaN, relative and absolute error ratios."""
        metrics = {}
        metrics["inf_nan_error_ratio"] = check_inf_nan_value(
            self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, self.rtol
        )
        metrics["rel_err_ratio"] = check_norm_value(self.normal_value_mask, self.rel_err, self.rtol)
        metrics["abs_err_ratio"] = check_small_value(
            self.abs_err, self.small_value_mask, self.small_value_atol
        )
        return metrics
@@ -0,0 +1,107 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import numpy as np
19
+
20
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \
21
+ check_small_value, get_error_balance
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
24
+ from msprobe.core.common.const import CompareConst
25
+
26
+
27
class AccumulativeErrorCompare(BaseCompare):
    """
    Accumulative-error comparison of device output against benchmark output.

    The comparison is tagged with ``CompareConst.ACCUMULATIVE_ERROR_COMPARE``
    and reports ``inf_nan_error_ratio``, ``rel_err_ratio``, ``abs_err_ratio``
    and the error balance ``eb``.

    Attributes populated by ``_pre_compare``:
        abs_bench: Absolute value of the benchmark output.
        abs_bench_with_eps: Absolute benchmark output with epsilon applied.
        both_finite_mask: Mask of positions where both outputs are finite.
        inf_nan_mask: Mask of positions where either output is inf or NaN.
        abs_err: Absolute error between benchmark and device outputs.
        bound: Accumulative error bound looked up from ``StandardConfig``
            by dtype; used as the tolerance for every ratio.
        rel_err: Relative error derived from ``abs_err`` and
            ``abs_bench_with_eps``.
        small_value: Threshold below which a value counts as "small".
        small_value_atol: Absolute tolerance returned with ``small_value``.
        small_value_mask: Mask of small-value positions.
        normal_value_mask: Mask of the remaining (normal) positions.

    Note:
        ``bench_output``, ``device_output`` and ``dtype`` are read but never
        assigned in this class, and the ``stat_*`` / ``_get_rel_err`` /
        ``_get_normal_value_mask`` / ``get_small_value_threshold`` helpers are
        not defined here; all are expected to be provided by ``BaseCompare``.

    See Also:
        BaseCompare: The base class for comparison classes.
        StandardConfig: The class containing standard configuration values.
    """

    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE

    def _get_bound(self):
        """Return the accumulative error bound ``StandardConfig`` holds for ``self.dtype``."""
        return StandardConfig.get_accumulative_error_bound(self.dtype)

    def _pre_compare(self):
        """
        Compute and cache every intermediate value ``_compute_metrics`` needs.

        Steps, in order:
            1. Absolute benchmark values and their epsilon-adjusted version.
            2. Finite and inf/NaN masks.
            3. Absolute error.
            4. Accumulative error bound for the dtype.
            5. Relative error from the absolute error and the
               epsilon-adjusted benchmark values.
            6. Small-value threshold and its absolute tolerance.
            7. Small-value mask from the benchmark values and finite mask.
            8. Normal-value mask from the finite mask and small-value mask.
        """
        bench_stats = self.stat_abs_bench_with_eps()
        self.abs_bench, self.abs_bench_with_eps = bench_stats

        finite_masks = self.stat_finite_and_infinite_mask()
        self.both_finite_mask, self.inf_nan_mask = finite_masks

        self.abs_err = self.stat_abs_error()
        self.bound = self._get_bound()
        self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps)

        small_value_cfg = self.get_small_value_threshold()
        self.small_value, self.small_value_atol = small_value_cfg

        self.small_value_mask = self.stat_small_value_mask(
            self.abs_bench, self.both_finite_mask, self.small_value
        )
        self.normal_value_mask = self._get_normal_value_mask(
            self.both_finite_mask, self.small_value_mask
        )

    def _compute_metrics(self):
        """Return a dict with the inf/NaN, relative and absolute error ratios plus ``eb``."""
        metrics = {}
        metrics["inf_nan_error_ratio"] = check_inf_nan_value(
            self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, self.bound
        )
        metrics["rel_err_ratio"] = check_norm_value(self.normal_value_mask, self.rel_err, self.bound)
        # NOTE(review): the small-value check is given self.bound, not
        # self.small_value_atol (which _pre_compare computes but nothing in
        # this class reads) -- presumably intentional for this standard; confirm.
        metrics["abs_err_ratio"] = check_small_value(self.abs_err, self.small_value_mask, self.bound)
        metrics["eb"] = get_error_balance(self.bench_output, self.device_output)
        return metrics