mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Callable
19
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \
20
+ ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST
21
+ from msprobe.core.common.const import CompareConst
22
+
23
+
24
class StandardRegistry:
    """
    Registry that maps API names (and optionally dtypes) to comparison functions.

    Each precision standard (absolute threshold, binary consistency, ULP, thousandth,
    accumulative error) is associated with a container of API names in
    ``api_standard_function_map``. Comparison functions are registered per standard name
    with ``register`` and resolved per API with ``get_comparison_function``.

    Attributes:
        comparison_functions (dict): Maps a standard name to the callable registered for it.
        api_standard_function_map (dict): Maps a standard name (a ``CompareConst`` constant)
            to the container of API names judged by that standard (membership is tested
            with ``in``).

    Methods:
        register(standard, func): Registers a comparison function under a standard name.
        get_comparison_function(api_name, dtype): Returns the function for an API / dtype.
        _get_standard_category(api_name, dtype): Resolves which standard applies.

    Note:
        A truthy ``dtype`` that is not in ``BINARY_COMPARE_UNSUPPORT_LIST`` always selects
        ``CompareConst.BINARY_CONSISTENCY``, regardless of ``api_name``. If no standard
        lists ``api_name``, ``CompareConst.BENCHMARK`` is used.

    See Also:
        BaseCompare: The base class for comparison classes.
    """
    def __init__(self):
        self.comparison_functions = {}
        # Scanned in insertion order by _get_standard_category: the first standard whose
        # API container includes the name wins.
        self.api_standard_function_map = {
            CompareConst.ABSOLUTE_THRESHOLD: absolute_standard_api,
            CompareConst.BINARY_CONSISTENCY: binary_standard_api,
            CompareConst.ULP_COMPARE: ulp_standard_api,
            CompareConst.THOUSANDTH_STANDARD: thousandth_standard_api,
            CompareConst.ACCUMULATIVE_ERROR_COMPARE: accumulative_error_standard_api
        }

    def register(self, standard: str, func: Callable) -> None:
        """
        Registers ``func`` as the comparison function for ``standard``.

        Registering the same standard again replaces the previously registered function.
        For the function to be found by ``get_comparison_function``, ``standard`` should
        equal a value that ``_get_standard_category`` can return.

        Args:
            standard (str): The name of the standard category.
            func (Callable): The comparison function to be registered.

        Raises:
            ValueError: If ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("The function to be registered must be callable.")
        self.comparison_functions[standard] = func

    def get_comparison_function(self, api_name, dtype=None):
        """
        Returns the comparison function for the standard that applies to ``api_name``.

        Args:
            api_name (str): The API name to look up.
            dtype (optional): Data type forwarded to ``_get_standard_category`` for the
                binary-consistency check. Defaults to None.

        Returns:
            Callable or None: The registered function, or ``None`` when no function has
            been registered for the resolved standard.
        """
        standard = self._get_standard_category(api_name, dtype)
        return self.comparison_functions.get(standard)

    def _get_standard_category(self, api_name, dtype=None):
        """
        Determines the standard category for a given API name and data type.

        If ``dtype`` is truthy and not in ``BINARY_COMPARE_UNSUPPORT_LIST``, the result is
        ``CompareConst.BINARY_CONSISTENCY``. Otherwise the first key of
        ``api_standard_function_map`` whose container includes ``api_name`` is returned.
        If none matches, ``CompareConst.BENCHMARK`` is returned.

        Args:
            api_name (str): The name of the API for which to determine the standard category.
            dtype (optional): The data type to check against BINARY_COMPARE_UNSUPPORT_LIST.
                Defaults to None.

        Returns:
            str: The name of the matching standard category.
        """
        if dtype and dtype not in BINARY_COMPARE_UNSUPPORT_LIST:
            return CompareConst.BINARY_CONSISTENCY
        for name, category in self.api_standard_function_map.items():
            if api_name in category:
                return name
        return CompareConst.BENCHMARK
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rel_err_ratio
19
+ from msprobe.core.common.const import CompareConst
20
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare
21
+
22
+
23
class ThousandthStdCompare(BaseCompare):
    """
    Comparison under the "thousandth" precision standard.

    Unlike comparisons that start from raw bench/device outputs, this class takes a
    relative-error value prepared by the caller (``input_data.rel_err_orign``) and
    reports a single ratio derived from it.

    Attributes:
        rel_err_orign: Relative-error values read from ``input_data``. Whether this is a
            scalar or an array is decided by the caller (TODO confirm).
        compare_column (object): Object read from ``input_data`` that collects metrics.

    Methods:
        _compute_metrics(): Returns ``{'rel_err_thousandth': ...}``.
    """
    def __init__(self, input_data):
        # NOTE(review): super().__init__(input_data) is not called here; confirm that
        # BaseCompare does not depend on attributes its initializer would set.
        self.rel_err_orign = input_data.rel_err_orign
        self.compare_column = input_data.compare_column

    def _pre_compare(self):
        """No preparation step: the relative error is supplied ready-made by the caller."""

    def _compute_metrics(self):
        """
        Computes the thousandth-standard metric from ``self.rel_err_orign``.

        Calls ``get_rel_err_ratio`` with ``CompareConst.THOUSAND_RATIO_THRESHOLD``. That
        helper must return a pair; the second element is discarded. Judging by the names,
        the first element is presumably the share of elements whose relative error lies
        within the thousandth threshold (confirm in compare/algorithm.py).

        Returns:
            dict: ``{'rel_err_thousandth': <first element returned by get_rel_err_ratio>}``.
        """
        thousandth_ratio, _discarded = get_rel_err_ratio(
            self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD
        )
        metric_dict = {'rel_err_thousandth': thousandth_ratio}
        return metric_dict
@@ -0,0 +1,200 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from collections import namedtuple
19
+ import numpy as np
20
+ import torch
21
+
22
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig
23
+ from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare, BasePrecisionCompare
24
+ from msprobe.core.common.const import Const, CompareConst
25
+ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import calc_ratio, get_ulp_err
26
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ApiPrecisionCompareColumn, check_inf_or_nan, \
27
+ is_inf_or_nan
28
+
29
+
30
# Consistency flags returned alongside the metrics by UlpPrecisionCompare._compute_ratio:
# one for the mean ULP error column, one for the ULP error proportion ratio.
UlpInfNanConsistency = namedtuple(
    'UlpInfNanConsistency',
    (
        'mean_ulp_err_inf_nan_consistency',
        'ulp_err_proportion_ratio_inf_nan_consistency',
    ),
)
32
+
33
+
34
class UlpCompare(BaseCompare):
    """
    ULP (unit in the last place) error statistics between bench and device outputs.

    ``_pre_compare`` stores the element-wise ULP errors produced by ``get_ulp_err`` in
    ``self.ulp_err``. ``_compute_metrics`` then reduces them to a maximum, a mean and the
    proportion above a dtype-dependent threshold.

    Attributes:
        bench_output, device_output, dtype: Presumably populated by ``BaseCompare.__init__``
            from ``input_data`` (not shown here; confirm in base_standard.py).
        ulp_err: ULP errors computed by ``_pre_compare``.
    """
    def __init__(self, input_data):
        super().__init__(input_data)

    @staticmethod
    def _stat_max_ulp_err(ulp_err):
        """Return the largest ULP error (``np.max``)."""
        return np.max(ulp_err)

    @staticmethod
    def _stat_mean_ulp_err(ulp_err):
        """Return the average ULP error (``np.mean``)."""
        return np.mean(ulp_err)

    def _stat_ulp_error_proportion(self, ulp_err):
        """
        Return the share of ULP errors strictly greater than the dtype's threshold.

        ``torch.float32`` uses ``CompareConst.ULP_FLOAT32_THRESHOLD``; every other dtype
        uses ``CompareConst.ULP_FLOAT16_THRESHOLD``. The count is divided by
        ``self.bench_output.size``.
        """
        if self.dtype == torch.float32:
            limit = CompareConst.ULP_FLOAT32_THRESHOLD
        else:
            limit = CompareConst.ULP_FLOAT16_THRESHOLD
        over_limit_count = np.sum(ulp_err > limit)
        return over_limit_count / self.bench_output.size

    def _pre_compare(self):
        """Compute and cache the element-wise ULP errors consumed by ``_compute_metrics``."""
        self.ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype)

    def _compute_metrics(self):
        """
        Reduce ``self.ulp_err`` to summary metrics.

        Returns:
            dict: With keys
                - "max_ulp_error": largest ULP error,
                - "mean_ulp_error": average ULP error,
                - "ulp_error_proportion": share of errors above the dtype threshold.
        """
        errors = self.ulp_err
        return {
            "max_ulp_error": self._stat_max_ulp_err(errors),
            "mean_ulp_error": self._stat_mean_ulp_err(errors),
            "ulp_error_proportion": self._stat_ulp_error_proportion(errors),
        }
100
+
101
+
102
class UlpPrecisionCompare(BasePrecisionCompare):
    """
    Precision-standard judgement for APIs compared with the ULP algorithm.

    Reads the recorded ULP metrics for the NPU and bench runs through the inherited
    ``_get_and_convert_values``. It derives the ULP-error-proportion ratio and maps the
    result to a PASS/ERROR status using ``StandardConfig.get_ulp_threshold``.
    """
    def __init__(self, input_data):
        super().__init__(input_data)
        self.compare_algorithm = CompareConst.ULP_COMPARE_ALGORITHM_NAME

    @staticmethod
    def _compute_ulp_err_proportion_ratio(npu_value, gpu_value, dtype):
        """
        Return ``(ratio, inf_nan_consistency, message)`` for the ULP error proportion.

        If either value is inf/nan, the triple comes from ``check_inf_or_nan``.
        Otherwise the ratio is ``calc_ratio(npu_value, gpu_value, dtype)``, the
        consistency flag is True and the message is empty.
        """
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            return check_inf_or_nan(npu_value, gpu_value, column_name)
        return calc_ratio(npu_value, gpu_value, dtype), True, ""

    def _compute_mean_ulp_err(self):
        """
        Return ``(npu_mean_ulp_err, inf_nan_consistency, message)``.

        The NPU value is returned unchanged. The bench value is only used for the
        inf/nan consistency check done by ``check_inf_or_nan``.
        """
        column_name = ApiPrecisionCompareColumn.MEAN_ULP_ERR
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        if is_inf_or_nan(npu_value) or is_inf_or_nan(gpu_value):
            _, mean_ulp_err_inf_nan_consistency, message = check_inf_or_nan(npu_value, gpu_value, column_name)
            return npu_value, mean_ulp_err_inf_nan_consistency, message
        return npu_value, True, ""

    def _compute_ulp_err_proportion(self):
        """Return the ``(npu, bench)`` values of the ULP error proportion column."""
        column_name = ApiPrecisionCompareColumn.ULP_ERR_PROPORTION
        npu_value, gpu_value = self._get_and_convert_values(column_name)
        return npu_value, gpu_value

    def _get_status(self, metrics, inf_nan_consistency):
        """
        Add the ULP status, the compare result and any failure message to ``metrics``.

        Args:
            metrics (dict): Metrics produced by ``_compute_ratio``; updated in place.
            inf_nan_consistency (UlpInfNanConsistency): Inf/nan consistency flags for the
                mean ULP error and the ULP error proportion ratio.

        Returns:
            dict: The same ``metrics`` object. It gains ``ULP_ERR_STATUS`` and
            ``COMPARE_RESULT`` (both set to the status), and the failure message is
            appended to ``COMPARE_MESSAGE``.
        """
        ulp_inf_nan_consistency = inf_nan_consistency.mean_ulp_err_inf_nan_consistency and \
            inf_nan_consistency.ulp_err_proportion_ratio_inf_nan_consistency

        if not ulp_inf_nan_consistency:
            # The inf/nan consistency check failed: report ERROR without consulting thresholds.
            status = CompareConst.ERROR
            final_message = "ERROR: ULP误差不满足标准\n"
        else:
            # float32 uses all three thresholds; every other dtype uses the fp16 rule.
            dtype = self.row_npu.get(ApiPrecisionCompareColumn.DEVICE_DTYPE)
            mean_ulp_err = metrics.get(CompareConst.MEAN_ULP_ERR)
            ulp_err_proportion = metrics.get(CompareConst.ULP_ERR_PROPORTION)
            ulp_err_proportion_ratio = metrics.get(CompareConst.ULP_ERR_PROPORTION_RATIO)
            if dtype == Const.TORCH_FLOAT32:
                status, final_message = \
                    self._get_fp32_ulp_err_status(mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio)
            else:
                status, final_message = \
                    self._get_fp16_ulp_err_status(ulp_err_proportion, ulp_err_proportion_ratio)

        metrics[CompareConst.COMPARE_MESSAGE] = metrics.get(CompareConst.COMPARE_MESSAGE, "") + final_message
        # Record ULP_ERR_STATUS on both paths so the column is always populated,
        # including the inf/nan-inconsistent failure case.
        metrics.update({CompareConst.ULP_ERR_STATUS: status})
        metrics.update({CompareConst.COMPARE_RESULT: status})
        return metrics

    def _get_fp32_ulp_err_status(self, mean_ulp_err, ulp_err_proportion, ulp_err_proportion_ratio):
        """
        Judge float32 ULP metrics against the float32 thresholds.

        PASS if any of the mean ULP error, the ULP error proportion or the proportion
        ratio is strictly below its threshold; ERROR otherwise.

        Returns:
            tuple: ``(status, message)``. The message is empty on PASS.
        """
        mean_ulp_err_threshold, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float32)
        if mean_ulp_err < mean_ulp_err_threshold:
            return CompareConst.PASS, ""
        if ulp_err_proportion < ulp_err_proportion_threshold:
            return CompareConst.PASS, ""
        if ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        compare_message = "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR, compare_message

    def _get_fp16_ulp_err_status(self, ulp_err_proportion, ulp_err_proportion_ratio):
        """
        Judge non-float32 ULP metrics against the float16 thresholds.

        The mean-ULP-error threshold is ignored. PASS if the ULP error proportion or
        the proportion ratio is strictly below its threshold; ERROR otherwise.

        Returns:
            tuple: ``(status, message)``. The message is empty on PASS.
        """
        _, ulp_err_proportion_threshold, ulp_err_proportion_ratio_threshold = \
            StandardConfig.get_ulp_threshold(torch.float16)
        if ulp_err_proportion < ulp_err_proportion_threshold:
            return CompareConst.PASS, ""
        if ulp_err_proportion_ratio < ulp_err_proportion_ratio_threshold:
            return CompareConst.PASS, ""
        compare_message = "ERROR: ULP误差不满足标准\n"
        return CompareConst.ERROR, compare_message

    def _compute_ratio(self):
        """
        Collect the ULP metrics and their inf/nan consistency flags.

        Returns:
            tuple: ``(metrics, UlpInfNanConsistency)``. ``metrics`` holds MEAN_ULP_ERR
            (NPU value), ULP_ERR_PROPORTION (NPU value), ULP_ERR_PROPORTION_RATIO and
            the concatenated COMPARE_MESSAGE.
        """
        compare_message = ""
        mean_ulp_err, mean_ulp_err_inf_nan_consistency, mean_ulp_err_message = self._compute_mean_ulp_err()
        compare_message += mean_ulp_err_message
        npu_ulp_err_proportion, gpu_ulp_err_proportion = self._compute_ulp_err_proportion()
        # NOTE(review): self.dtype is presumably set by BasePrecisionCompare; _get_status reads
        # the dtype from row_npu instead — confirm the two always agree.
        ulp_err_proportion_ratio, ulp_err_proportion_ratio_inf_nan_consistency, ulp_err_proportion_ratio_message = \
            self._compute_ulp_err_proportion_ratio(npu_ulp_err_proportion, gpu_ulp_err_proportion, str(self.dtype))
        compare_message += ulp_err_proportion_ratio_message
        metrics = {
            CompareConst.MEAN_ULP_ERR: mean_ulp_err,
            CompareConst.ULP_ERR_PROPORTION: npu_ulp_err_proportion,
            CompareConst.ULP_ERR_PROPORTION_RATIO: ulp_err_proportion_ratio,
            CompareConst.COMPARE_MESSAGE: compare_message
        }
        return metrics, UlpInfNanConsistency(mean_ulp_err_inf_nan_consistency,
                                             ulp_err_proportion_ratio_inf_nan_consistency)
@@ -28,6 +28,7 @@ from msprobe.pytorch.common.log import logger
28
28
  from msprobe.pytorch.common.utils import load_pt
29
29
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
30
30
 
31
+
31
32
  TORCH_TYPE = ["torch.device", "torch.dtype"]
32
33
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
33
34
  FLOAT_TYPE = [
@@ -139,7 +140,12 @@ def gen_random_tensor(info, convert_type):
139
140
  high_info = [high, high_origin]
140
141
  data_dtype = info.get('dtype')
141
142
  shape = tuple(info.get('shape'))
142
- if not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
143
+ if 0 in shape:
144
+ low, low_origin = 0, 0
145
+ high, high_origin = 0, 0
146
+ low_info = [low, low_origin]
147
+ high_info = [high, high_origin]
148
+ elif not isinstance(low, (int, float)) or not isinstance(high, (int, float)):
143
149
  error_info = f'Data info Min: {low} , Max: {high}, info type must be int or float.'
144
150
  raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
145
151
  if data_dtype == "torch.bool":
@@ -305,6 +311,19 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
305
311
  kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
306
312
  elif value is None:
307
313
  kwargs_params[key] = None
314
+ elif key == 'atten_mask' and api_name == 'npu_fusion_attention':
315
+ sparse_mode = kwargs_params.get('sparse_mode', {})
316
+ if isinstance(sparse_mode, dict):
317
+ sparse_mode_value = sparse_mode.get('value', 0)
318
+ elif isinstance(sparse_mode, int):
319
+ sparse_mode_value = sparse_mode
320
+ else:
321
+ msg = f'The sparse_mode value is not int or dict, but {type(sparse_mode)}'
322
+ raise CompareException(CompareException.INVALID_PARAM_ERROR, msg)
323
+ if sparse_mode_value in Const.FA_SPECIAL_SPARSE_MODE:
324
+ kwargs_params[key] = gen_atten_mask(value, convert_type, real_data_path)
325
+ else:
326
+ kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
308
327
  elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
309
328
  kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
310
329
  elif value.get('type') in TORCH_TYPE:
@@ -314,6 +333,30 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
314
333
  return kwargs_params
315
334
 
316
335
 
336
def gen_atten_mask(info, convert_type, real_data_path):
    """
    Function Description:
        Based on API basic information, generate the input parameter atten_mask
        for API forward running.
    Parameter:
        info: API basic information. Dict
        convert_type: convert ori_type to dist_type flag.
        real_data_path: the root directory for storing real data.
    Return:
        - When info['type'] is in TENSOR_DATA_LIST and a data path is available:
          the tensor produced by gen_real_tensor.
        - When info['type'] is in TENSOR_DATA_LIST but no data path is available:
          a synthetic 2048x2048 torch.bool mask built with torch.triu(..., diagonal=1).
          Entries strictly above the main diagonal are True; the diagonal and the
          lower triangle are False.
        - Otherwise: None, and a warning is logged.
    """
    check_object_type(info, dict)
    data_type = info.get('type')
    # Older dumps use 'datapath'; fall back to 'data_name' when it is absent.
    data_path = info.get('datapath', info.get('data_name'))
    data_path = get_full_data_path(data_path, real_data_path)

    if data_type not in TENSOR_DATA_LIST:
        # Keep the best-effort behaviour (return None) but make it visible.
        logger.warning(f"atten_mask type {data_type} is not a tensor type, atten_mask will be None.")
        return None

    if data_path:
        return gen_real_tensor(data_path, convert_type)

    # No real data was dumped, so synthesize a default mask.
    # torch.triu(ones, diagonal=1) keeps only the entries strictly above the main
    # diagonal (True). The diagonal itself and everything below it are False.
    # 2048x2048 is the atten_mask shape npu_fusion_attention uses for sparse_mode
    # 2, 3 and 4. Presumably this matches Const.FA_SPECIAL_SPARSE_MODE — TODO confirm.
    mask_size = 2048
    return torch.triu(torch.ones([mask_size, mask_size]), diagonal=1).to(torch.bool)
358
+
359
+
317
360
  def gen_torch_kwargs(kwargs_params, key, value):
318
361
  if value.get('type') != "torch.device":
319
362
  module_name, attribute_name = get_module_and_atttribute_name(value.get('value'))
@@ -341,6 +384,23 @@ def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=No
341
384
  return kwargs_item_result
342
385
 
343
386
 
387
def get_output_dtype(api_info):
    """
    Function Description:
        Look up the dtype recorded for the first output of an API and resolve it
        to the object named by that dtype string.
    Parameter:
        api_info: API basic information. Dict
    Return:
        The resolved dtype object when api_info[Const.OUTPUT] is non-empty, its first
        entry is a dict, and that entry's Const.DTYPE value is listed in
        Const.TORCH_FLOAT_DTYPE. Otherwise None.
    """
    outputs = api_info.get(Const.OUTPUT)
    if not outputs:
        return None

    first_output = outputs[0]
    if not isinstance(first_output, dict):
        return None

    dtype_name = first_output.get(Const.DTYPE)
    # Only float dtypes listed in Const.TORCH_FLOAT_DTYPE are resolved.
    if dtype_name not in Const.TORCH_FLOAT_DTYPE:
        return None

    module_name, attribute_name = get_module_and_atttribute_name(dtype_name)
    return get_attribute(module_name, attribute_name)
402
+
403
+
344
404
  def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_data_path=None):
345
405
  """
346
406
  Function Description:
@@ -367,4 +427,5 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
367
427
  else:
368
428
  logger.warning(f'Warning: No args in {api_info} ')
369
429
  args_params = []
370
- return args_params, kwargs_params
430
+ output_dtype = get_output_dtype(api_info)
431
+ return args_params, kwargs_params, output_dtype
@@ -33,13 +33,15 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
33
33
  from msprobe.pytorch.common import parse_json_info_forward_backward
34
34
  from msprobe.pytorch.common.log import logger
35
35
  from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
36
- check_path_before_create, create_directory
36
+ create_directory, load_json, save_json
37
37
  from msprobe.core.common.file_utils import remove_path
38
- from msprobe.core.common.const import FileCheckConst
38
+ from msprobe.core.common.const import FileCheckConst, Const
39
+ from msprobe.core.common.utils import CompareException
39
40
 
40
41
 
41
42
  def split_json_file(input_file, num_splits, filter_api):
42
43
  forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
44
+ input_dir = os.path.dirname(os.path.abspath(input_file))
43
45
  if filter_api:
44
46
  forward_data = preprocess_forward_content(forward_data)
45
47
  for data_name in list(forward_data.keys()):
@@ -47,9 +49,11 @@ def split_json_file(input_file, num_splits, filter_api):
47
49
  for data_name in list(backward_data.keys()):
48
50
  backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
49
51
 
50
- with FileOpen(input_file, 'r') as file:
51
- input_data = json.load(file)
52
- input_data.pop("data")
52
+ input_data = load_json(input_file)
53
+ if input_data.get("data") is None:
54
+ logger.error("Invalid input file, 'data' field is missing")
55
+ raise CompareException("Invalid input file, 'data' field is missing")
56
+ input_data.pop("data")
53
57
 
54
58
  items = list(forward_data.items())
55
59
  total_items = len(items)
@@ -68,9 +72,8 @@ def split_json_file(input_file, num_splits, filter_api):
68
72
  **backward_data
69
73
  }
70
74
  }
71
- split_filename = f"temp_part{i}.json"
72
- with FileOpen(split_filename, 'w') as split_file:
73
- json.dump(temp_data, split_file)
75
+ split_filename = os.path.join(input_dir, f"temp_part{i}.json")
76
+ save_json(split_filename, temp_data)
74
77
  split_files.append(split_filename)
75
78
 
76
79
  return split_files, total_items
@@ -122,7 +125,7 @@ def run_parallel_ut(config):
122
125
  if output == '':
123
126
  break
124
127
  if '[ERROR]' in output:
125
- logger.warning(output, end='')
128
+ logger.warning(output)
126
129
  sys.stdout.flush()
127
130
  except ValueError as e:
128
131
  logger.warning(f"An error occurred while reading subprocess output: {e}")
@@ -182,16 +185,19 @@ def run_parallel_ut(config):
182
185
 
183
186
 
184
187
  def prepare_config(args):
185
- check_link(args.api_info_file)
186
- api_info = os.path.realpath(args.api_info_file)
187
- check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
188
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
189
- check_path_before_create(out_path)
188
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
189
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
190
+ api_info = api_info_file_checker.common_check()
191
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
190
192
  create_directory(out_path)
191
193
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
192
194
  out_path = out_path_checker.common_check()
193
195
  split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
194
- config_path = os.path.realpath(args.config_path) if args.config_path else None
196
+ config_path = args.config_path if args.config_path else None
197
+ if config_path:
198
+ config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
199
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
200
+ config_path = config_path_checker.common_check()
195
201
  result_csv_path = args.result_csv_path or os.path.join(
196
202
  out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
197
203
  if not args.result_csv_path: