mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,15 +1,34 @@
1
- import json
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  import os
17
+ from tqdm import tqdm
3
18
 
4
- from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv
5
- from msprobe.core.common.utils import add_time_as_suffix
6
19
  from msprobe.core.common.const import Const, CompareConst, MsCompareConst
7
- from msprobe.mindspore.common.log import logger
20
+ from msprobe.core.common.file_utils import FileOpen, create_directory, write_csv, load_json, load_yaml
21
+ from msprobe.core.common.utils import add_time_as_suffix
8
22
  from msprobe.mindspore.api_accuracy_checker.api_info import ApiInfo
9
23
  from msprobe.mindspore.api_accuracy_checker.api_runner import api_runner, ApiInputAggregation
10
24
  from msprobe.mindspore.api_accuracy_checker.base_compare_algorithm import compare_algorithms
25
+ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
11
26
  from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
12
27
  trim_output_compute_element_list)
28
+ from msprobe.mindspore.common.log import logger
29
+
30
+ cur_path = os.path.dirname(os.path.realpath(__file__))
31
+ yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
13
32
 
14
33
 
15
34
  class BasicInfoAndStatus:
@@ -21,6 +40,7 @@ class BasicInfoAndStatus:
21
40
  self.status = status
22
41
  self.err_msg = err_msg
23
42
 
43
+
24
44
  class ResultCsvEntry:
25
45
  def __init__(self) -> None:
26
46
  self.forward_pass_status = None
@@ -30,14 +50,21 @@ class ResultCsvEntry:
30
50
  self.overall_err_msg = None
31
51
 
32
52
 
53
+ class ProcessResultPacket:
54
+ def __init__(self, process_status, result, err_msg) -> None:
55
+ self.process_status = process_status
56
+ self.result = result
57
+ self.err_msg = err_msg
58
+
59
+
33
60
  class ApiAccuracyChecker:
34
- def __init__(self):
61
+ def __init__(self, args):
35
62
  self.api_infos = dict()
36
- self.results = dict()
63
+ self.data_manager = DataManager(args.out_path, args.result_csv_path) # 在初始化时实例化 DataManager
37
64
 
38
65
  @staticmethod
39
66
  def run_and_compare_helper(api_info, api_name_str, api_input_aggregation, forward_or_backward):
40
- '''
67
+ """
41
68
  Args:
42
69
  api_info: ApiInfo
43
70
  api_name_str: str
@@ -51,7 +78,7 @@ class ApiAccuracyChecker:
51
78
  get mindspore api output, run torch api and get output.
52
79
  compare output.
53
80
  record compare result.
54
- '''
81
+ """
55
82
  # get output
56
83
  if global_context.get_is_constructed():
57
84
  # constructed situation, need use constructed input to run mindspore api getting tested_output
@@ -80,13 +107,13 @@ class ApiAccuracyChecker:
80
107
  compare_result_dict[compare_algorithm_name] = compare_result
81
108
 
82
109
  if compare_result_dict.get(CompareConst.COSINE).pass_status == CompareConst.PASS and \
83
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
110
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).pass_status == CompareConst.PASS:
84
111
  status = CompareConst.PASS
85
112
  err_msg = ""
86
113
  else:
87
114
  status = CompareConst.ERROR
88
- err_msg = compare_result_dict.get(CompareConst.COSINE).err_msg + \
89
- compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg
115
+ err_msg = (compare_result_dict.get(CompareConst.COSINE).err_msg +
116
+ compare_result_dict.get(CompareConst.MAX_ABS_ERR).err_msg)
90
117
  basic_info_status = \
91
118
  BasicInfoAndStatus(api_name_with_slot, bench_dtype, tested_dtype, shape, status, err_msg)
92
119
  output_list.append(tuple([api_name_str, forward_or_backward, basic_info_status, compare_result_dict]))
@@ -94,13 +121,13 @@ class ApiAccuracyChecker:
94
121
 
95
122
  @staticmethod
96
123
  def prepare_api_input_aggregation(api_info, forward_or_backward=Const.FORWARD):
97
- '''
124
+ """
98
125
  Args:
99
126
  api_info: ApiInfo
100
127
  forward_or_backward: str
101
128
  Returns:
102
129
  ApiInputAggregation
103
- '''
130
+ """
104
131
  forward_inputs = api_info.get_compute_element_list(Const.FORWARD, Const.INPUT)
105
132
  kwargs = api_info.get_kwargs()
106
133
  if forward_or_backward == Const.FORWARD:
@@ -109,19 +136,42 @@ class ApiAccuracyChecker:
109
136
  gradient_inputs = api_info.get_compute_element_list(Const.BACKWARD, Const.INPUT)
110
137
  return ApiInputAggregation(forward_inputs, kwargs, gradient_inputs)
111
138
 
139
+ @staticmethod
140
+ def is_api_checkable(api_name_str):
141
+ '''
142
+ Args:
143
+ api_name_str: str, e.g. "MintFunctional.relu.0.forward", key in data field of api_info.json
144
+ Returns:
145
+ is_checkable: bool
146
+ Description:
147
+ tell whether this api is checkable based on the key in "data" dict in api_info.json
148
+ '''
149
+ api_name_str_list = api_name_str.split(Const.SEP)
150
+ if len(api_name_str_list) < MsCompareConst.API_NAME_STR_LENGTH:
151
+ return False
152
+ api_type_str = api_name_str_list[0]
153
+ real_api_str = Const.SEP.join(api_name_str_list[1:-2])
154
+ api_list = load_yaml(yaml_path)
155
+ supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
156
+ if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
157
+ return True
158
+ if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list:
159
+ return True
160
+ return False
161
+
112
162
  def parse(self, api_info_path):
113
- with FileOpen(api_info_path, "r") as f:
114
- api_info_dict = json.load(f)
163
+ api_info_dict = load_json(api_info_path)
115
164
 
116
165
  # init global context
117
166
  task = check_and_get_from_json_dict(api_info_dict, MsCompareConst.TASK_FIELD,
118
- "task field in api_info.json",accepted_type=str,
167
+ "task field in api_info.json", accepted_type=str,
119
168
  accepted_value=(MsCompareConst.STATISTICS_TASK,
120
169
  MsCompareConst.TENSOR_TASK))
121
170
  is_constructed = task == MsCompareConst.STATISTICS_TASK
122
171
  if not is_constructed:
123
172
  dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
124
- "dump_data_dir field in api_info.json", accepted_type=str)
173
+ "dump_data_dir field in api_info.json",
174
+ accepted_type=str)
125
175
  else:
126
176
  dump_data_dir = ""
127
177
  global_context.init(is_constructed, dump_data_dir)
@@ -129,14 +179,12 @@ class ApiAccuracyChecker:
129
179
  api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
130
180
  "data field in api_info.json", accepted_type=dict)
131
181
  for api_name, api_info in api_info_data.items():
132
- is_mint = api_name.split(Const.SEP)[0] in \
133
- (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
134
- if not is_mint:
182
+ if not self.is_api_checkable(api_name):
135
183
  continue
136
184
  forbackward_str = api_name.split(Const.SEP)[-1]
137
185
  if forbackward_str not in (Const.FORWARD, Const.BACKWARD):
138
186
  logger.warning(f"api: {api_name} is not recognized as forward api or backward api, skip this.")
139
- api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
187
+ api_name = Const.SEP.join(api_name.split(Const.SEP)[:-1]) # www.xxx.yyy.zzz --> www.xxx.yyy
140
188
  if api_name not in self.api_infos:
141
189
  self.api_infos[api_name] = ApiInfo(api_name)
142
190
 
@@ -145,135 +193,87 @@ class ApiAccuracyChecker:
145
193
  else:
146
194
  self.api_infos[api_name].load_backward_info(api_info)
147
195
 
196
+ def process_forward(self, api_name_str, api_info):
197
+ """处理前向检查"""
198
+ if not api_info.check_forward_info():
199
+ logger.debug(f"api: {api_name_str} is lack of forward information, skip forward check.")
200
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
201
+ result=None,
202
+ err_msg=f"forward info of {api_name_str} is not found")
203
+ return process_result_packet
204
+
205
+ try:
206
+ forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
207
+ except Exception as e:
208
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} forward api. "
209
+ f"Skipping forward check. Detailed exception information: {e}.")
210
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
211
+ result=None, err_msg=f"{e}")
212
+ return process_result_packet
213
+
214
+ try:
215
+ forward_output_list = self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation,
216
+ Const.FORWARD)
217
+ except Exception as e:
218
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} forward api. "
219
+ f"Detailed exception information: {e}.")
220
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
221
+ result=None, err_msg=f"{e}")
222
+ return process_result_packet
223
+
224
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
225
+ result=forward_output_list, err_msg="")
226
+ return process_result_packet
227
+
228
+ def process_backward(self, api_name_str, api_info):
229
+ """处理反向检查"""
230
+ if not api_info.check_backward_info():
231
+ logger.debug(f"api: {api_name_str} is lack of backward information, skipping backward check.")
232
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.API_NOT_FOUND,
233
+ result=None,
234
+ err_msg=f"backward info of {api_name_str} is not found")
235
+ return process_result_packet
236
+
237
+ try:
238
+ backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
239
+ except Exception as e:
240
+ logger.warning(f"Exception occurs when getting inputs for {api_name_str} backward api. "
241
+ f"Skipping backward check. Detailed exception information: {e}.")
242
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
243
+ result=None, err_msg=f"{e}")
244
+ return process_result_packet
245
+
246
+ try:
247
+ backward_output_list = self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation,
248
+ Const.BACKWARD)
249
+ except Exception as e:
250
+ logger.warning(f"Exception occurs when running and comparing {api_name_str} backward api. "
251
+ f"Detailed exception information: {e}.")
252
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.EXCEPTION_SKIP,
253
+ result=None, err_msg=f"{e}")
254
+ return process_result_packet
255
+
256
+ process_result_packet = ProcessResultPacket(process_status=MsCompareConst.ProcessStatus.SUCCESS,
257
+ result=backward_output_list, err_msg="")
258
+ return process_result_packet
259
+
148
260
  def run_and_compare(self):
149
- for api_name_str, api_info in self.api_infos.items():
150
- if not api_info.check_forward_info():
151
- logger.warning(f"api: {api_name_str} is lack of forward infomation, skip forward and backward check.")
152
- continue
153
- try:
154
- forward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.FORWARD)
155
- except Exception as e:
156
- logger.warning(f"exception occurs when getting inputs for {api_name_str} forward api. "
157
- f"skip forward and backward check. detailed exception information: {e}.")
158
- continue
159
- forward_output_list = None
160
- try:
161
- forward_output_list = \
162
- self.run_and_compare_helper(api_info, api_name_str, forward_inputs_aggregation, Const.FORWARD)
163
- except Exception as e:
164
- logger.warning(f"exception occurs when running and comparing {api_name_str} forward api. "
165
- f"detailed exception information: {e}.")
166
- self.record(forward_output_list)
167
-
168
- if not api_info.check_backward_info():
169
- logger.warning(f"api: {api_name_str} is lack of backward infomation, skip backward check.")
261
+ for api_name_str, api_info in tqdm(self.api_infos.items()):
262
+ if not self.data_manager.is_unique_api(api_name_str):
170
263
  continue
171
- try:
172
- backward_inputs_aggregation = self.prepare_api_input_aggregation(api_info, Const.BACKWARD)
173
- except Exception as e:
174
- logger.warning(f"exception occurs when getting inputs for {api_name_str} backward api. "
175
- f"skip backward check. detailed exception information: {e}.")
176
- continue
177
- backward_output_list = None
178
- try:
179
- backward_output_list = \
180
- self.run_and_compare_helper(api_info, api_name_str, backward_inputs_aggregation, Const.BACKWARD)
181
- except Exception as e:
182
- logger.warning(f"exception occurs when running and comparing {api_name_str} backward api. "
183
- f"detailed exception information: {e}.")
184
- self.record(backward_output_list)
185
-
186
- def record(self, output_list):
187
- if output_list is None:
188
- return
189
- for output in output_list:
190
- api_real_name, forward_or_backward, basic_info, compare_result_dict = output
191
- key = tuple([api_real_name, forward_or_backward])
192
- if key not in self.results:
193
- self.results[key] = []
194
- self.results[key].append(tuple([basic_info, compare_result_dict]))
195
-
196
-
197
- def to_detail_csv(self, csv_dir):
198
- # detail_csv
199
- detail_csv = []
200
- detail_csv_header_basic_info = [
201
- MsCompareConst.DETAIL_CSV_API_NAME,
202
- MsCompareConst.DETAIL_CSV_BENCH_DTYPE,
203
- MsCompareConst.DETAIL_CSV_TESTED_DTYPE,
204
- MsCompareConst.DETAIL_CSV_SHAPE,
205
- ]
206
- detail_csv_header_compare_result = list(compare_algorithms.keys())
207
- detail_csv_header_status = [
208
- MsCompareConst.DETAIL_CSV_PASS_STATUS,
209
- MsCompareConst.DETAIL_CSV_MESSAGE,
210
- ]
211
-
212
- detail_csv_header = detail_csv_header_basic_info + detail_csv_header_compare_result + detail_csv_header_status
213
- detail_csv.append(detail_csv_header)
214
-
215
- for _, results in self.results.items():
216
- # detail csv
217
- for res in results:
218
- basic_info, compare_result_dict = res
219
- csv_row_basic_info = \
220
- [basic_info.api_name, basic_info.bench_dtype, basic_info.tested_dtype, basic_info.shape]
221
- csv_row_compare_result = list(compare_result_dict.get(algorithm_name).compare_value \
222
- for algorithm_name in detail_csv_header_compare_result)
223
- csv_row_status = [basic_info.status, basic_info.err_msg]
224
- csv_row = csv_row_basic_info + csv_row_compare_result + csv_row_status
225
- detail_csv.append(csv_row)
226
-
227
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.DETAIL_CSV_FILE_NAME))
228
- create_directory(csv_dir)
229
- write_csv(detail_csv, file_name, mode="w")
230
-
231
-
232
- def to_result_csv(self, csv_dir):
233
- result_csv_dict = dict()
234
- for key, results in self.results.items():
235
- api_real_name, forward_or_backward = key
236
- forward_or_backward_pass_status = CompareConst.PASS
237
- forward_or_backward_overall_err_msg = ""
238
- # detail csv
239
- for res in results:
240
- basic_info, _ = res
241
- if basic_info.status != CompareConst.PASS:
242
- forward_or_backward_pass_status = CompareConst.ERROR
243
- forward_or_backward_overall_err_msg += basic_info.err_msg
244
- forward_or_backward_overall_err_msg = \
245
- "" if forward_or_backward_pass_status == CompareConst.PASS else forward_or_backward_overall_err_msg
246
-
247
- #result_csv_dict
248
- if api_real_name not in result_csv_dict:
249
- result_csv_dict[api_real_name] = ResultCsvEntry()
250
- if forward_or_backward == Const.FORWARD:
251
- result_csv_dict[api_real_name].forward_pass_status = forward_or_backward_pass_status
252
- result_csv_dict[api_real_name].forward_err_msg = forward_or_backward_overall_err_msg
253
- else:
254
- result_csv_dict[api_real_name].backward_pass_status = forward_or_backward_pass_status
255
- result_csv_dict[api_real_name].backward_err_msg = forward_or_backward_overall_err_msg
256
-
257
- #result_csv
258
- result_csv = []
259
- result_csv_header = [
260
- MsCompareConst.DETAIL_CSV_API_NAME,
261
- MsCompareConst.RESULT_CSV_FORWARD_TEST_SUCCESS,
262
- MsCompareConst.RESULT_CSV_BACKWARD_TEST_SUCCESS,
263
- MsCompareConst.DETAIL_CSV_MESSAGE,
264
- ]
265
- result_csv.append(result_csv_header)
266
-
267
- for api_name, result_csv_entry in result_csv_dict.items():
268
- if result_csv_entry.forward_pass_status == CompareConst.PASS and \
269
- result_csv_entry.backward_pass_status == CompareConst.PASS:
270
- overall_err_msg = ""
271
- else:
272
- overall_err_msg = result_csv_entry.forward_err_msg + result_csv_entry.backward_err_msg
273
- row = [api_name, result_csv_entry.forward_pass_status,
274
- result_csv_entry.backward_pass_status, overall_err_msg]
275
- result_csv.append(row)
276
-
277
- file_name = os.path.join(csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME))
278
- create_directory(csv_dir)
279
- write_csv(result_csv, file_name, mode="w")
264
+
265
+ # 处理前向
266
+ process_result_packet = self.process_forward(api_name_str, api_info)
267
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
268
+ self.data_manager.record(process_result_packet.result)
269
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
270
+ self.data_manager.record_exception_skip(api_name_str, Const.FORWARD, process_result_packet.err_msg)
271
+
272
+ # 处理反向
273
+ process_result_packet = self.process_backward(api_name_str, api_info)
274
+ if process_result_packet.process_status is MsCompareConst.ProcessStatus.SUCCESS:
275
+ self.data_manager.record(process_result_packet.result)
276
+ elif process_result_packet.process_status == MsCompareConst.ProcessStatus.EXCEPTION_SKIP:
277
+ self.data_manager.record_exception_skip(api_name_str, Const.BACKWARD, process_result_packet.err_msg)
278
+
279
+ self.data_manager.save_results(api_name_str)
@@ -1,9 +1,25 @@
1
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
2
16
  from msprobe.core.common.const import Const
3
- from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
4
17
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
5
- from msprobe.mindspore.common.log import logger
6
18
  from msprobe.core.common.utils import is_invalid_pattern
19
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
20
+ from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict
21
+ from msprobe.mindspore.common.log import logger
22
+
7
23
 
8
24
  class ApiInfo:
9
25
  def __init__(self, api_name):
@@ -66,11 +82,10 @@ class ApiInfo:
66
82
  err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string"
67
83
  logger.error_log_with_exp(err_msg,
68
84
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
69
- if not isinstance(compute_element_info, (list, dict)):
70
- err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict"
85
+ if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None):
86
+ err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null"
71
87
  logger.error_log_with_exp(err_msg,
72
88
  ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed))
73
89
  kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info)
74
90
  for key_str, compute_element_info in kwargs_dict.items()}
75
91
  return kwargs_compute_element_dict
76
-
@@ -1,15 +1,27 @@
1
-
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
2
15
 
3
16
  import mindspore
4
17
  import torch
5
18
  from mindspore import ops
6
-
7
- from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
8
19
  from msprobe.core.common.const import Const, MsCompareConst
9
20
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
10
- from msprobe.mindspore.common.log import logger
11
- from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
21
+ from msprobe.mindspore.api_accuracy_checker.compute_element import ComputeElement
12
22
  from msprobe.mindspore.api_accuracy_checker.type_mapping import float_dtype_str_list, torch_dtype_to_dtype_str
23
+ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
24
+ from msprobe.mindspore.common.log import logger
13
25
 
14
26
 
15
27
  class ApiInputAggregation:
@@ -24,11 +36,23 @@ class ApiInputAggregation:
24
36
  self.kwargs = kwargs
25
37
  self.gradient_inputs = gradient_inputs
26
38
 
39
+
27
40
  api_parent_module_mapping = {
28
41
  (MsCompareConst.MINT, Const.MS_FRAMEWORK): mindspore.mint,
29
42
  (MsCompareConst.MINT, Const.PT_FRAMEWORK): torch,
30
43
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
31
- (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional
44
+ (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
45
+ (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
46
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
47
+ }
48
+
49
+ api_parent_module_str_mapping = {
50
+ (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
51
+ (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
52
+ (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
53
+ (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
54
+ (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
55
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
32
56
  }
33
57
 
34
58
 
@@ -60,7 +84,7 @@ class ApiRunner:
60
84
  api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
61
85
 
62
86
  Return:
63
- api_type_str: str, Union["MintFunctional", "Mint"]
87
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
64
88
  api_sub_name: str, e.g. "relu"
65
89
  '''
66
90
  api_name_list = api_name_str.split(Const.SEP)
@@ -68,8 +92,8 @@ class ApiRunner:
68
92
  err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
69
93
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
70
94
  api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
71
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL]:
72
- err_msg = f"ApiRunner.get_info_from_name failed: not mint or mint.nn.functional api"
95
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
96
+ err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
73
97
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
74
98
 
75
99
  return api_type_str, api_sub_name
@@ -78,7 +102,7 @@ class ApiRunner:
78
102
  def get_api_instance(api_type_str, api_sub_name, api_platform):
79
103
  '''
80
104
  Args:
81
- api_type_str: str, Union["MintFunctional", "Mint"]
105
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
82
106
  api_sub_name: str, e.g. "relu"
83
107
  api_platform: str: Union["mindpore", "torch"]
84
108
 
@@ -92,9 +116,8 @@ class ApiRunner:
92
116
  '''
93
117
 
94
118
  api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
95
- module_str = "mindspore.mint." if api_platform == Const.MS_FRAMEWORK else "torch."
96
- submodule_str = "nn.functional." if api_type_str == MsCompareConst.MINT_FUNCTIONAL else ""
97
- full_api_name = module_str + submodule_str + api_sub_name
119
+ api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
120
+ full_api_name = api_parent_module_str + Const.SEP + api_sub_name
98
121
  if not hasattr(api_parent_module, api_sub_name):
99
122
  err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
100
123
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
@@ -115,7 +138,7 @@ class ApiRunner:
115
138
  gradient_inputs = api_input_aggregation.gradient_inputs
116
139
 
117
140
  if forward_or_backward == Const.FORWARD:
118
- forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
141
+ forward_result = api_instance(*inputs, **kwargs) # can be single tensor or tuple
119
142
  forward_result_tuple = convert_to_tuple(forward_result)
120
143
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in forward_result_tuple]
121
144
  else:
@@ -127,18 +150,20 @@ class ApiRunner:
127
150
  if api_platform == Const.MS_FRAMEWORK:
128
151
  if len(gradient_inputs) == 1:
129
152
  gradient_inputs = gradient_inputs[0]
153
+
130
154
  def api_with_kwargs(*forward_inputs):
131
155
  return api_instance(*forward_inputs, **kwargs)
156
+
132
157
  grad_func = ops.GradOperation(get_all=True, sens_param=True)(api_with_kwargs)
133
- backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
158
+ backward_result = grad_func(*inputs, gradient_inputs) # can be single tensor or tuple
134
159
  backward_result_tuple = convert_to_tuple(backward_result)
135
160
  res_compute_element_list = [ComputeElement(parameter=api_res) for api_res in backward_result_tuple]
136
161
  else:
137
- #set requires_grad
162
+ # set requires_grad
138
163
  requires_grad_index = []
139
164
  for index, tensor in enumerate(inputs):
140
165
  if isinstance(tensor, torch.Tensor) and \
141
- torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
166
+ torch_dtype_to_dtype_str.get(tensor.dtype) in float_dtype_str_list:
142
167
  setattr(tensor, "requires_grad", True)
143
168
  requires_grad_index.append(index)
144
169
  forward_results = api_instance(*inputs, **kwargs)
@@ -153,4 +178,4 @@ class ApiRunner:
153
178
  return res_compute_element_list
154
179
 
155
180
 
156
- api_runner = ApiRunner()
181
+ api_runner = ApiRunner()
@@ -1,12 +1,27 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from abc import ABC, abstractmethod
2
17
 
3
18
  import mindspore
4
- import torch
5
19
  import numpy as np
6
-
20
+ import torch
21
+ from msprobe.core.common.const import CompareConst, MsCompareConst
7
22
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
8
23
  from msprobe.mindspore.common.log import logger
9
- from msprobe.core.common.const import CompareConst, MsCompareConst
24
+
10
25
 
11
26
  class CompareResult:
12
27
  def __init__(self, compare_value, pass_status, err_msg):
@@ -28,7 +43,7 @@ class BaseCompareAlgorithm(ABC):
28
43
  CompareConst.MAX_ABS_ERR: {
29
44
  CompareConst.PASS: "",
30
45
  CompareConst.ERROR: "max absolute difference is greater than " \
31
- f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
46
+ f"threshold: {CompareConst.MAX_ABS_ERR_THRESHOLD} ",
32
47
  CompareConst.SKIP: "two inputs are not valid for computing max absolute difference, skip comparing ",
33
48
  },
34
49
  CompareConst.MAX_RELATIVE_ERR: {
@@ -68,7 +83,7 @@ class BaseCompareAlgorithm(ABC):
68
83
  ndarray = tensor.to(torch.float64, copy=True).numpy()
69
84
  else:
70
85
  err_msg = "BaseCompareAlgorithm.convert_to_np_float64_ndarray failed: " \
71
- "input is not mindspore.Tensor or torch.Tensor"
86
+ "input is not mindspore.Tensor or torch.Tensor"
72
87
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
73
88
  return ndarray
74
89
 
@@ -189,9 +204,8 @@ class MaxRelativeDiffCompareAlgorithm(BaseCompareAlgorithm):
189
204
  return CompareConst.ERROR
190
205
 
191
206
 
192
-
193
207
  compare_algorithms = {
194
208
  CompareConst.COSINE: CosineSimilarityCompareAlgorithm(),
195
209
  CompareConst.MAX_ABS_ERR: MaxAbsoluteDiffCompareAlgorithm(),
196
210
  CompareConst.MAX_RELATIVE_ERR: MaxRelativeDiffCompareAlgorithm(),
197
- }
211
+ }