mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,77 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ # list of api that can be checked
17
+
18
+ tensor:
19
+ - add_
20
+ - add
21
+ - addmm_
22
+ - all
23
+ - allclose
24
+ - any
25
+ - bool
26
+ - byte
27
+ - ceil
28
+ - clamp
29
+ - contiguous
30
+ - copy_
31
+ - cos
32
+ - clone
33
+ - cumprod
34
+ - expand_as
35
+ - flatten
36
+ - float
37
+ - half
38
+ - int
39
+ - is_contiguous
40
+ - isnan
41
+ - item
42
+ - log
43
+ - log2
44
+ - long
45
+ - masked_fill
46
+ - max
47
+ - mean
48
+ - min
49
+ - numel
50
+ - numpy
51
+ - repeat
52
+ - repeat_interleave
53
+ - reshape
54
+ - round
55
+ - select
56
+ - sin
57
+ - size
58
+ - split
59
+ - sqrt
60
+ - square
61
+ - sub
62
+ - swapaxes
63
+ - to
64
+ - t
65
+ - tolist
66
+ - topk
67
+ - transpose
68
+ - trunc
69
+ - type
70
+ - unsqueeze
71
+ - view
72
+ - view_as
73
+ - fill_
74
+ - floor_
75
+ - clamp_
76
+ - type_as
77
+ - zero_
@@ -1,6 +1,69 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ from msprobe.core.common.file_utils import check_file_or_directory_path, create_directory
20
+ from msprobe.core.common.utils import Const, MsprobeBaseException
21
+
22
+
23
class UniqueDeviceAction(argparse.Action):
    """argparse action that validates a list of device ids before storing it.

    The list is rejected (via ``parser.error``) when it contains a repeated id
    or when any id lies outside the closed range [0, 4095]. Otherwise the list
    is stored unchanged on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Collapsing to a set removes duplicates; a size difference means a repeat.
        distinct_ids = set(values)
        if len(distinct_ids) != len(values):
            parser.error("device id must be unique")
        for dev_id in values:
            if dev_id < 0 or dev_id > 4095:
                parser.error(f"the argument 'device_id' must be in range [0, 4095], but got {dev_id}")
        setattr(namespace, self.dest, values)
32
+
33
+
1
34
def add_api_accuracy_checker_argument(parser):
    """Register the command-line options of the API accuracy checker on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or sub-parser) that is extended in place.

    Options added:
        -api_info / --api_info_file: required path to the json file produced by the
            api param tool.
        -o / --out_path: output directory for the ut task results, default "./".
        -csv_path / --result_csv_path: optional path to an existing result csv, used to
            continue a previous run; default "" (no resume).
    """
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
                        help="<Required> The api param tool result file: generate from api param tool, "
                             "a json file.")
    parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
                        help="<optional> The ut task result out path.")
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        required=False,
                        help="<optional> The existing result csv file to continue from.")
42
+
43
+
44
def multi_add_api_accuracy_checker_argument(parser):
    """Register the command-line options of the multi-device API accuracy checker.

    Adds the same ``-api_info``, ``-o`` and ``-csv_path`` options as the
    single-device checker, plus ``-d/--device`` for selecting the device ids.

    Args:
        parser: an ``argparse.ArgumentParser`` (or sub-parser) that is extended in place.
    """
    parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", type=str, required=True,
                        help="<Required> The api param tool result file: generate from api param tool, "
                             "a json file.")
    parser.add_argument("-o", "--out_path", dest="out_path", default="./", type=str, required=False,
                        help="<optional> The ut task result out path.")
    parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str,
                        required=False,
                        help="<optional> The existing result csv file to continue from.")
    # The options below are specific to multi-device (parallel) execution.
    # UniqueDeviceAction rejects duplicate ids and ids outside [0, 4095]; the help
    # text states the same range so users are not misled.
    parser.add_argument("-d", "--device", dest="device_id", nargs='+', type=int,
                        help="<optional> set device id to run ut, must be unique and in range 0-4095",
                        default=[0], required=False, action=UniqueDeviceAction)
56
+
57
+
58
def check_args(args):
    """Normalise and validate the parsed checker arguments, mutating ``args`` in place.

    - ``api_info_file`` is made absolute and then validated.
    - ``out_path`` falls back to "./" when empty, is made absolute, and the
      directory is created.
    - ``result_csv_path``, when non-empty, is made absolute and validated.
    """
    api_info_abs = os.path.abspath(args.api_info_file)
    args.api_info_file = api_info_abs
    check_file_or_directory_path(api_info_abs)

    # An explicitly empty out path means "current directory".
    target_dir = "./" if args.out_path == "" else args.out_path
    args.out_path = os.path.abspath(target_dir)
    create_directory(args.out_path)

    if not args.result_csv_path:
        return
    args.result_csv_path = os.path.abspath(args.result_csv_path)
    check_file_or_directory_path(args.result_csv_path)
@@ -1,21 +1,37 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
 
3
18
  import mindspore
4
- import torch
5
19
  import numpy as np
6
-
7
- from msprobe.mindspore.common.log import logger
20
+ import torch
21
+ from mindspore._c_expression import typing
22
+ from msprobe.core.common.const import Const
8
23
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
9
24
  from msprobe.core.common.file_utils import load_npy
10
- from msprobe.mindspore.api_accuracy_checker.type_mapping import (dtype_str_to_np_dtype, api_info_type_str_to_type,
25
+ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
11
26
  ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
12
27
  dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
13
28
  dtype_str_to_torch_dtype, type_to_api_info_type_str,
14
29
  DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
15
- MINDSPORE_TENSOR_TYPE_STR, float_dtype_str_list,
16
- int_dtype_str_list)
17
- from msprobe.core.common.const import Const
30
+ MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
31
+ SLICE_TYPE_STR, TORCH_DTYPE_TYPE_STR,
32
+ float_dtype_str_list, int_dtype_str_list)
18
33
  from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
34
+ from msprobe.mindspore.common.log import logger
19
35
 
20
36
 
21
37
  class MstensorMetaData:
@@ -26,6 +42,12 @@ class MstensorMetaData:
26
42
  self.minimum = minimum
27
43
  self.shape = shape
28
44
 
45
+
46
class DtypeMetaData:
    """Framework-neutral placeholder for a dtype argument.

    Only the dtype's string name is stored. ComputeElement resolves it to a
    concrete MindSpore or PyTorch dtype when the parameter is requested for a
    given platform.
    """

    def __init__(self, dtype_str) -> None:
        # Name of the dtype as used by the dtype_str_to_* mapping tables.
        self.dtype_str = dtype_str

    def __repr__(self) -> str:
        # Makes the placeholder readable in logs and error messages.
        return f"{type(self).__name__}(dtype_str={self.dtype_str!r})"
49
+
50
+
29
51
  class ComputeElement:
30
52
  def __init__(self, compute_element_info=None, parameter=None):
31
53
  self.supported_parameter_type = tuple(type_to_api_info_type_str.keys()) + tuple([torch.Tensor, tuple])
@@ -56,12 +78,10 @@ class ComputeElement:
56
78
  else:
57
79
  torch_dtype = dtype_str_to_torch_dtype.get(dtype_str)
58
80
 
59
- if dtype_str in float_dtype_str_list:
60
- middle_dtype = mindspore.float64
61
- elif dtype_str in int_dtype_str_list:
81
+ if dtype_str in int_dtype_str_list:
62
82
  middle_dtype = mindspore.int64
63
83
  else:
64
- middle_dtype = mindspore.uint64
84
+ middle_dtype = mindspore.float64
65
85
  np_ndarray = ms_tensor.astype(middle_dtype).numpy()
66
86
  torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
67
87
  return torch_tensor
@@ -84,10 +104,10 @@ class ComputeElement:
84
104
  else:
85
105
  ms_dtype = dtype_str_to_ms_dtype.get(dtype_str)
86
106
 
87
- if dtype_str in float_dtype_str_list:
88
- middle_dtype = torch.float64
89
- elif dtype_str in int_dtype_str_list:
107
+ if dtype_str in int_dtype_str_list:
90
108
  middle_dtype = torch.int64
109
+ else:
110
+ middle_dtype = torch.float64
91
111
  np_ndarray = torch_tensor.to(middle_dtype, copy=True).numpy()
92
112
  ms_tensor = mindspore.Tensor.from_numpy(np_ndarray).astype(ms_dtype)
93
113
  return ms_tensor
@@ -118,6 +138,11 @@ class ComputeElement:
118
138
  for compute_element in self.parameter])
119
139
  elif isinstance(self.parameter, self.supported_parameter_type):
120
140
  parameter_tmp = self.parameter
141
+ elif isinstance(self.parameter, DtypeMetaData):
142
+ if tensor_platform == Const.MS_FRAMEWORK:
143
+ parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
144
+ else:
145
+ parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
121
146
  elif isinstance(self.parameter, MstensorMetaData):
122
147
  mstensor_meta_data = self.parameter
123
148
  ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
@@ -130,13 +155,13 @@ class ComputeElement:
130
155
  parameter_tmp = mindspore.Tensor(ndarray, dtype=ms_dtype)
131
156
  else:
132
157
  err_msg = "ComputeElement.get_parameter failed: self.parameter type is not in " \
133
- "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
158
+ "(int, float, str, slice, bool, torch.Tensor, mindspore.Tensor, MstensorMetaData)"
134
159
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
135
160
 
136
161
  # if necessary, do transfer
137
162
  if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
138
163
  parameter = self.transfer_to_torch_tensor(parameter_tmp)
139
- elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform ==Const.MS_FRAMEWORK:
164
+ elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
140
165
  parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
141
166
  else:
142
167
  parameter = parameter_tmp
@@ -183,34 +208,38 @@ class ComputeElement:
183
208
  else:
184
209
  type_str = check_and_get_from_json_dict(compute_element_info, "type", "type field in api_info.json",
185
210
  accepted_type=str, accepted_value=api_info_type_str_to_type.keys())
186
-
211
+ self.shape = tuple()
212
+ self.dtype_str = type_str
187
213
  if type_str == MINDSPORE_TENSOR_TYPE_STR:
188
214
  self._init_from_mstensor_compute_element_info(compute_element_info)
189
- else: # type_str in ("slice", "int", "float", "bool")
215
+ else:
190
216
  value = check_and_get_from_json_dict(compute_element_info, "value", "value field in api_info.json")
191
- self.shape = tuple()
192
- self.dtype_str = type_str
193
- self.parameter = slice(*tuple(value)) if type_str == "slice" else value
217
+ if type_str == MINDSPORE_DTYPE_TYPE_STR:
218
+ self.parameter = DtypeMetaData(value)
219
+ elif type_str == SLICE_TYPE_STR:
220
+ self.parameter = slice(*tuple(value))
221
+ else: # type_str in ("str", "int", "float", "bool")
222
+ self.parameter = value
194
223
 
195
224
  def _init_from_mstensor_compute_element_info(self, compute_element_info):
196
225
  '''
197
226
  do not load real tensor, only record meta data
198
227
  '''
199
228
  dtype_str = check_and_get_from_json_dict(compute_element_info, "dtype", "dtype field in api_info.json",
200
- accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
229
+ accepted_type=str, accepted_value=dtype_str_to_ms_dtype.keys())
201
230
  shape = check_and_get_from_json_dict(compute_element_info, "shape", "shape field in api_info.json",
202
- accepted_type=(list,))
231
+ accepted_type=(list,))
203
232
  if global_context.get_is_constructed():
204
233
  maximum = check_and_get_from_json_dict(compute_element_info, "Max", "Max field in api_info.json",
205
- accepted_type=(int, float))
234
+ accepted_type=(int, float))
206
235
  minimum = check_and_get_from_json_dict(compute_element_info, "Min", "Min field in api_info.json",
207
- accepted_type=(int, float))
236
+ accepted_type=(int, float))
208
237
 
209
238
  npy_path = None
210
239
  else:
211
240
  maximum, minimum = None, None
212
241
  data_name = check_and_get_from_json_dict(compute_element_info, "data_name",
213
- "data_name field in api_info.json", accepted_type=(str,))
242
+ "data_name field in api_info.json", accepted_type=(str,))
214
243
  npy_path = os.path.join(global_context.get_dump_data_dir(), data_name)
215
244
  mstensor_meta_data = MstensorMetaData(dtype_str, npy_path, maximum, minimum, shape)
216
245
  self.parameter = mstensor_meta_data
@@ -219,9 +248,10 @@ class ComputeElement:
219
248
 
220
249
  def _init_with_parameter(self, parameter):
221
250
  self.parameter = parameter
251
+ self.shape = tuple()
222
252
  if not isinstance(parameter, self.supported_parameter_type):
223
253
  err_msg = "ComputeElement._init_with_parameter failed: " \
224
- "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
254
+ "parameter type is not in (int, float, str, slice, bool, torch.Tensor, mindspore.Tensor)"
225
255
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
226
256
  if isinstance(parameter, mindspore.Tensor):
227
257
  self.shape = tuple(parameter.shape)
@@ -229,11 +259,14 @@ class ComputeElement:
229
259
  elif isinstance(parameter, torch.Tensor):
230
260
  self.shape = tuple(parameter.shape)
231
261
  self.dtype_str = torch_dtype_to_dtype_str.get(parameter.dtype)
262
+ elif isinstance(parameter, typing.Type):
263
+ self.dtype_str = MINDSPORE_DTYPE_TYPE_STR
264
+ self.parameter = DtypeMetaData(ms_dtype_to_dtype_str.get(parameter))
265
+ elif isinstance(parameter, torch.dtype):
266
+ self.dtype_str = TORCH_DTYPE_TYPE_STR
267
+ self.parameter = DtypeMetaData(torch_dtype_to_dtype_str.get(parameter))
232
268
  elif isinstance(parameter, tuple):
233
- self.shape = tuple()
234
269
  self.dtype_str = TUPLE_TYPE_STR
235
270
  self.parameter = tuple([ComputeElement(parameter=param) for param in parameter])
236
271
  else:
237
- self.shape = tuple()
238
- self.dtype_str = \
239
- TUPLE_TYPE_STR if isinstance(parameter, tuple) else type_to_api_info_type_str.get(type(parameter))
272
+ self.dtype_str = type_to_api_info_type_str.get(type(parameter))
@@ -0,0 +1,301 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import csv
18
+
19
+ from msprobe.core.common.const import Const, CompareConst, MsCompareConst
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, read_csv
21
+ from msprobe.core.common.utils import add_time_as_suffix, MsprobeBaseException
22
+ from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
23
+ from msprobe.core.common.file_utils import check_file_or_directory_path
24
+ from msprobe.mindspore.common.log import logger
25
+
26
+
27
class ResultCsvEntry:
    """Accumulator for one API's row in the result (summary) CSV.

    Holds the forward/backward pass status and the error messages gathered for that API.
    """

    def __init__(self) -> None:
        # Statuses remain None until a forward/backward result is recorded for the API.
        self.forward_pass_status, self.backward_pass_status = None, None
        # Messages start as empty strings so they can be concatenated directly.
        self.forward_err_msg, self.backward_err_msg = "", ""
        # NOTE(review): not read anywhere in this file; confirm whether other modules use it.
        self.overall_err_msg = None
34
+
35
+
36
def write_csv_header(csv_path, header_func):
    """Append the header row returned by ``header_func()`` to ``csv_path``.

    The row is written unconditionally; deciding whether this is the first write
    is the caller's responsibility.
    """
    rows = [header_func()]
    logger.debug(f"Writing CSV header: {rows[0]}")
    write_csv(rows, csv_path, mode="a+")
41
+
42
+
43
def get_result_csv_header():
    """Return the column names of the result (summary) CSV, in output order."""
    columns = (
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
        MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    )
    return list(columns)
51
+
52
+
53
def get_detail_csv_header():
    """Return the column names of the detail CSV.

    Layout: basic info (API name, bench dtype, tested dtype, shape), then one column
    per compare algorithm in ``compare_algorithms`` key order, then pass status and message.
    """
    return [
        MsCompareConst.DETAIL_CSV_API_NAME,
        MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
        MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
        MsCompareConst.DETAIL_CSV_SHAPE,
        *compare_algorithms.keys(),
        MsCompareConst.DETAIL_CSV_PASS_STATUS,
        MsCompareConst.DETAIL_CSV_MESSAGE,
    ]
67
+
68
+
69
def check_csv_header(headers, required_constants, csv_path):
    """Check that every required column name appears in the CSV header row.

    Matching is by substring, so a first header cell that carries a UTF-8
    byte-order mark still matches its column name.

    Args:
        headers: header cells read from the CSV file.
        required_constants: column names that must each appear in some header cell.
        csv_path: path of the CSV file, used only in the error message.

    Returns:
        True when all required columns are present. Callers may therefore use the
        result in a condition, e.g. ``if check_csv_header(...):``.

    Raises:
        MsprobeBaseException: with MISSING_HEADER_ERROR, listing the missing columns.
    """
    missing_constants = [
        const for const in required_constants
        if not any(const in header for header in headers)
    ]

    if missing_constants:
        raise MsprobeBaseException(
            MsprobeBaseException.MISSING_HEADER_ERROR,
            f"{csv_path} 缺少以下必需的表头字段: {missing_constants}"
        )
    return True
78
+
79
+
80
class DataManager:
    """Collect per-API comparison results and append them to the detail and result CSV files.

    Supports resuming from a previous run. When ``result_csv_path`` is given, output is
    appended to that file and to its sibling "details" file. API names already present in
    it are loaded, so ``is_unique_api`` reports them as already seen.
    """

    def __init__(self, csv_dir, result_csv_path):
        """
        Args:
            csv_dir: directory for newly created CSV files. When resuming it is replaced
                by the directory of ``result_csv_path``.
            result_csv_path: result CSV from a previous run to resume from, or a falsy
                value to start a fresh run.
        """
        # (api_real_name, forward_or_backward) -> list of (basic_info, compare_result_dict)
        self.results = {}
        # api_name -> {Const.FORWARD: err_msg or None, Const.BACKWARD: err_msg or None}
        self.results_exception_skip = {}
        # True until the CSV headers have been written.
        self.is_first_write = True
        self.csv_dir = csv_dir
        # Names of APIs already seen (including those loaded from a resumed CSV).
        self.api_names_set = set()
        if result_csv_path:
            # Resume: reuse the previous output files and load the API names already checked.
            self.resume_from_last_csv(result_csv_path)
            self.initialize_api_names_set(result_csv_path)
        else:
            # Fresh run: build timestamped output paths. Headers are written by the first save_results call.
            self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
            self.detail_out_path = os.path.join(
                self.csv_dir,
                os.path.basename(self.result_out_path).replace("result", "details")
            )

        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)

        if self.result_out_path and os.path.exists(self.result_out_path):
            check_file_or_directory_path(self.result_out_path)

    @staticmethod
    def _find_header_index(headers, column_name):
        """Return the index of the first header cell containing ``column_name``, or None.

        Substring matching is used because the first header cell may carry a UTF-8
        byte-order mark.
        """
        for index, header in enumerate(headers):
            if column_name in header:
                return index
        return None

    def initialize_api_names_set(self, result_csv_path):
        """Load the API names recorded in an existing result CSV into ``self.api_names_set``.

        Raises:
            MsprobeBaseException: if the header lacks a required result column
                (raised by ``check_csv_header``).
        """
        csv_data = read_csv(result_csv_path, as_pd=False)

        # An empty file yields no header row.
        headers = csv_data[0] if csv_data else []

        # check_csv_header reports missing columns by raising, not through its return
        # value, so call it unconditionally and do not gate the loading below on it.
        check_csv_header(headers, get_result_csv_header(), result_csv_path)

        api_name_index = self._find_header_index(headers, MsCompareConst.DETAIL_CSV_API_NAME)
        if api_name_index is None:
            logger.warning(f"{result_csv_path} No column contains 'API Name'.")
            return

        for row in csv_data[1:]:  # skip the header row
            if row and len(row) > api_name_index:
                api_name = row[api_name_index]
                if api_name:
                    self.api_names_set.add(api_name)

        logger.debug(f"Initialized API names set from existing CSV: {self.api_names_set}")

    def is_unique_api(self, api_name):
        """Return False if ``api_name`` was already seen; otherwise record it and return True."""
        if api_name in self.api_names_set:
            return False
        self.api_names_set.add(api_name)
        return True

    def resume_from_last_csv(self, result_csv_path):
        """Point the output paths at a previous run's result CSV and its sibling details CSV.

        Headers are not written again, since the existing files already have them.
        """
        last_dir = os.path.dirname(result_csv_path)

        self.csv_dir = last_dir
        self.detail_out_path = os.path.join(last_dir, os.path.basename(result_csv_path).replace("result", "details"))
        if self.detail_out_path and os.path.exists(self.detail_out_path):
            check_file_or_directory_path(self.detail_out_path)
        self.result_out_path = result_csv_path
        self.is_first_write = False

    def save_results(self, api_name_str):
        """Write the buffered detail rows and result summary to CSV, then clear the buffers.

        On the first call of a fresh run, the header rows of both files are written first.

        Args:
            api_name_str: API name, used only in log messages.
        """
        if self.is_first_write:
            logger.info("Writing CSV headers for the first time.")
            write_csv_header(self.detail_out_path, get_detail_csv_header)
            write_csv_header(self.result_out_path, get_result_csv_header)
            self.is_first_write = False  # avoid writing the headers again

        logger.debug("Starting to write detailed output to CSV.")
        self.to_detail_csv(self.detail_out_path)
        logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

        logger.debug("Starting to write result summary to CSV.")
        self.to_result_csv(self.result_out_path)
        logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

        # Reset the buffers for the next call.
        self.clear_results()

    def record(self, output_list):
        """Buffer comparison outputs, grouped by ``(api_real_name, forward_or_backward)``.

        Each output is a 4-tuple
        ``(api_real_name, forward_or_backward, basic_info, compare_result_dict)``.
        A None ``output_list`` is ignored.
        """
        if output_list is None:
            return
        for api_real_name, forward_or_backward, basic_info, compare_result_dict in output_list:
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def record_exception_skip(self, api_name, forward_or_backward, err_msg):
        '''
        Record exception-skip information into self.results_exception_skip.

        self.results_exception_skip: dict{str: dict{Const.FORWARD: str/None, Const.BACKWARD: str/None}}
        The outer key is the api_name; each inner value is the err_msg for that direction,
        or None if that direction was not skipped.
        '''
        if api_name not in self.results_exception_skip:
            self.results_exception_skip[api_name] = {Const.FORWARD: None, Const.BACKWARD: None}
        self.results_exception_skip[api_name][forward_or_backward] = err_msg

    def clear_results(self):
        """Clear the buffered results and exception-skip information."""
        logger.debug("Clearing self.results data.")
        self.results.clear()
        self.results_exception_skip.clear()

    def to_detail_csv(self, csv_path):
        """Append one detail row per buffered comparison result to ``csv_path``.

        The row layout matches ``get_detail_csv_header``: basic info, one value per
        compare algorithm (in ``compare_algorithms`` key order), then status and message.
        """
        logger.debug("Preparing detail CSV headers and rows.")
        detail_csv = []

        detail_csv_header_compare_result = list(compare_algorithms.keys())

        for results in self.results.values():
            for basic_info, compare_result_dict in results:
                csv_row_basic_info = [
                    basic_info.api_name,
                    basic_info.bench_dtype,
                    basic_info.tested_dtype,
                    basic_info.shape
                ]
                # NOTE(review): assumes compare_result_dict has an entry for every algorithm;
                # a missing one raises AttributeError on None.
                csv_row_compare_result = [
                    compare_result_dict.get(algorithm_name).compare_value
                    for algorithm_name in detail_csv_header_compare_result
                ]
                csv_row_status = [basic_info.status, basic_info.err_msg]
                csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
                detail_csv.append(csv_row)
                logger.debug(f"Detail CSV row added: {csv_row}")

        logger.debug(f"Writing detail CSV to {csv_path}.")
        write_csv(detail_csv, csv_path, mode="a+")
        logger.debug(f"Detail CSV written successfully to {csv_path}.")

    def to_result_csv(self, csv_path):
        '''
        Append one summary row per API to ``csv_path``.

        Depends on both self.results and self.results_exception_skip:
        - A direction's status is PASS when every recorded result passed, ERROR otherwise.
        - Exception-skip information overrides that direction's status with SKIP and
          appends its message.
        - APIs that only have exception-skip information get a row of their own.
        Exception-skip entries merged into a result row are removed from
        self.results_exception_skip.
        '''
        logger.debug("Preparing result CSV data.")
        result_csv = []

        result_csv_dict = {}
        for (api_real_name, forward_or_backward), results in self.results.items():
            pass_status = CompareConst.PASS
            overall_err_msg = ""

            for basic_info, _ in results:
                if basic_info.status != CompareConst.PASS:
                    pass_status = CompareConst.ERROR
                overall_err_msg += basic_info.err_msg

            # Only keep the collected messages when some result failed.
            overall_err_msg = "" if pass_status == CompareConst.PASS else overall_err_msg

            if api_real_name not in result_csv_dict:
                result_csv_dict[api_real_name] = ResultCsvEntry()
            entry = result_csv_dict[api_real_name]
            if forward_or_backward == Const.FORWARD:
                entry.forward_pass_status = pass_status
                entry.forward_err_msg = overall_err_msg
            else:
                entry.backward_pass_status = pass_status
                entry.backward_err_msg = overall_err_msg

        for api_name, entry in result_csv_dict.items():
            both_passed = (entry.forward_pass_status == CompareConst.PASS and
                           entry.backward_pass_status == CompareConst.PASS)
            overall_err_msg = "" if both_passed else entry.forward_err_msg + entry.backward_err_msg
            row = [
                api_name,
                entry.forward_pass_status,
                entry.backward_pass_status,
                overall_err_msg
            ]
            # Override statuses with SKIP where this API has exception-skip information.
            if api_name in self.results_exception_skip:
                skip_info = self.results_exception_skip[api_name]
                if skip_info[Const.FORWARD] is not None:
                    row[1] = CompareConst.SKIP
                    row[-1] += skip_info[Const.FORWARD]
                if skip_info[Const.BACKWARD] is not None:
                    row[2] = CompareConst.SKIP
                    row[-1] += skip_info[Const.BACKWARD]
                # Merged into this row; do not emit it again below.
                del self.results_exception_skip[api_name]
            result_csv.append(row)
            logger.debug(f"Result CSV row added: {row}")

        # APIs that were skipped without any recorded comparison result.
        for api_name, current_exception_skip in self.results_exception_skip.items():
            forward_status = None
            backward_status = None
            err_msg = ""
            if current_exception_skip[Const.FORWARD] is not None:
                forward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.FORWARD]
            if current_exception_skip[Const.BACKWARD] is not None:
                backward_status = CompareConst.SKIP
                err_msg += current_exception_skip[Const.BACKWARD]
            result_csv.append([api_name, forward_status, backward_status, err_msg])

        write_csv(result_csv, csv_path, mode="a+")
        logger.debug(f"Result CSV written successfully to {csv_path}.")

        # Ensure headers are not written again by a later save_results call.
        self.is_first_write = False
@@ -1,9 +1,34 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.mindspore.api_accuracy_checker.api_accuracy_checker import ApiAccuracyChecker
2
17
 
18
+ from msprobe.mindspore.api_accuracy_checker.multi_api_accuracy_checker import MultiApiAccuracyChecker
19
+
20
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import check_args
21
+
3
22
 
def api_checker_main(args):
    """Validate ``args`` and run ``ApiAccuracyChecker`` on the APIs in ``args.api_info_file``."""
    check_args(args)
    checker = ApiAccuracyChecker(args)
    checker.parse(args.api_info_file)
    checker.run_and_compare()
28
+
29
+
30
def mul_api_checker_main(args):
    """Validate ``args`` and run ``MultiApiAccuracyChecker`` on the APIs in ``args.api_info_file``."""
    check_args(args)
    multi_checker = MultiApiAccuracyChecker(args)
    multi_checker.parse(args.api_info_file)
    multi_checker.run_and_compare()