mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}精度工具数据dump标准性能基线报告.md +0 -0
@@ -1,7 +1,6 @@
1
1
  import argparse
2
2
  import os
3
3
  import csv
4
- import re
5
4
  import sys
6
5
  import time
7
6
  import gc
@@ -18,28 +17,35 @@ else:
18
17
  import torch
19
18
  from tqdm import tqdm
20
19
 
21
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
20
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
21
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api
22
22
  from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, api_info_preprocess, \
23
+ from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
24
24
  initialize_save_path, UtDataProcessor
25
25
  from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
26
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
28
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
29
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
30
27
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
31
28
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
32
29
  from msprobe.core.common.file_check import FileOpen, FileChecker, \
33
- change_mode, check_file_suffix, check_link, check_path_before_create, create_directory
30
+ change_mode, check_path_before_create, create_directory
34
31
  from msprobe.pytorch.common.log import logger
32
+ from msprobe.core.common.utils import get_json_contents
33
+ from msprobe.pytorch.pt_config import parse_json_config
35
34
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
35
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData, move2device_exec
36
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
37
+
36
38
 
37
39
  current_time = time.strftime("%Y%m%d%H%M%S")
38
40
  UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
39
41
  RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
40
42
  DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
41
43
  RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
42
- 'save_error_data', 'is_continue_run_ut', 'real_data_path'])
44
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
45
+ 'black_list', 'error_data_path', 'online_config'])
46
+
47
+ OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
48
+
43
49
  not_backward_list = ['repeat_interleave']
44
50
  not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
45
51
  not_raise_dtype_set = {'type_as'}
@@ -66,19 +72,6 @@ tqdm_params = {
66
72
  }
67
73
 
68
74
 
69
- def exec_api(api_type, api_name, args, kwargs):
70
- if api_type == "Functional":
71
- functional_api = FunctionalOPTemplate(api_name, str, False)
72
- out = functional_api.forward(*args, **kwargs)
73
- if api_type == "Tensor":
74
- tensor_api = TensorOPTemplate(api_name, str, False)
75
- out = tensor_api.forward(*args, **kwargs)
76
- if api_type == "Torch":
77
- torch_api = TorchOPTemplate(api_name, str, False)
78
- out = torch_api.forward(*args, **kwargs)
79
- return out
80
-
81
-
82
75
  def deal_detach(arg, to_detach=True):
83
76
  return arg.detach() if to_detach else arg
84
77
 
@@ -130,7 +123,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
130
123
  elif isinstance(arg_in, torch.Tensor):
131
124
  if need_backward and arg_in.requires_grad:
132
125
  arg_in = deal_detach(raise_bench_data_dtype(
133
- api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
126
+ api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
134
127
  temp_arg_in = arg_in * 1
135
128
  arg_in = temp_arg_in.type_as(arg_in)
136
129
  arg_in.retain_grad()
@@ -173,32 +166,48 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
173
166
 
174
167
  def run_ut(config):
175
168
  logger.info("start UT test")
176
- logger.info(f"UT task result will be saved in {config.result_csv_path}")
177
- logger.info(f"UT task details will be saved in {config.details_csv_path}")
169
+ if config.online_config.is_online:
170
+ logger.info(f"UT task result will be saved in {config.result_csv_path}".replace(".csv", "_rank*.csv"))
171
+ logger.info(f"UT task details will be saved in {config.details_csv_path}".replace(".csv", "_rank*.csv"))
172
+ else:
173
+ logger.info(f"UT task result will be saved in {config.result_csv_path}")
174
+ logger.info(f"UT task details will be saved in {config.details_csv_path}")
175
+
178
176
  if config.save_error_data:
179
- error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
180
- logger.info(f"UT task error_datas will be saved in {error_data_path}")
181
- compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
182
- with FileOpen(config.result_csv_path, 'r') as file:
183
- csv_reader = csv.reader(file)
184
- next(csv_reader)
185
- api_name_set = {row[0] for row in csv_reader}
177
+ logger.info(f"UT task error_datas will be saved in {config.error_data_path}")
178
+ compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config)
179
+
180
+ if config.online_config.is_online:
181
+ run_api_online(config, compare)
182
+ else:
183
+ with FileOpen(config.result_csv_path, 'r') as file:
184
+ csv_reader = csv.reader(file)
185
+ next(csv_reader)
186
+ api_name_set = {row[0] for row in csv_reader}
187
+ run_api_offline(config, compare, api_name_set)
188
+ for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
189
+ change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
190
+ change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
191
+ logger.info(f"UT task result csv is saved in {result_csv_path}")
192
+ logger.info(f"UT task details csv is saved in {details_csv_path}")
193
+ compare.print_pretest_result()
194
+
195
+
196
+ def run_api_offline(config, compare, api_name_set):
186
197
  for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)):
187
198
  if api_full_name in api_name_set:
188
199
  continue
189
- if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api
200
+ if is_unsupported_api(api_full_name):
190
201
  continue
202
+ [_, api_name, _] = api_full_name.split(Const.SEP)
191
203
  try:
192
- if msCheckerConfig.white_list:
193
- [_, api_name, _] = api_full_name.split(Const.SEP)
194
- if api_name not in set(msCheckerConfig.white_list):
195
- continue
204
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
205
+ continue
196
206
  data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
197
207
  is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
198
208
  if config.save_error_data:
199
- do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
209
+ do_save_error_data(api_full_name, data_info, config.error_data_path, is_fwd_success, is_bwd_success)
200
210
  except Exception as err:
201
- [_, api_name, _] = api_full_name.split(Const.SEP)
202
211
  if "expected scalar type Long" in str(err):
203
212
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
204
213
  f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
@@ -214,9 +223,71 @@ def run_ut(config):
214
223
  else:
215
224
  torch.npu.empty_cache()
216
225
  gc.collect()
217
- change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY)
218
- change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY)
219
- compare.print_pretest_result()
226
+
227
+
228
+ def run_api_online(config, compare):
229
+ attl = init_attl(config.online_config)
230
+ dispatcher = ConsumerDispatcher(compare=compare)
231
+ dispatcher.start(handle_func=run_torch_api_online, config=config)
232
+
233
+ def tcp_communication_flow():
234
+ while True:
235
+ api_data = attl.recv()
236
+ if api_data == 'STOP_':
237
+ continue
238
+ if api_data == 'KILL_':
239
+ time.sleep(1)
240
+ logger.info("==========接收到STOP信号==========")
241
+ dispatcher.stop()
242
+ attl.stop_serve()
243
+ time.sleep(1)
244
+ break
245
+ if not isinstance(api_data, ApiData):
246
+ continue
247
+ api_full_name = api_data.name
248
+ [_, api_name, _] = api_full_name.split(Const.SEP)
249
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
250
+ continue
251
+ dispatcher.update_consume_queue(api_data)
252
+
253
+ def shared_storage_communication_flow():
254
+ flag_num = -1
255
+ while True:
256
+ api_data = attl.download()
257
+ if api_data == "start":
258
+ if flag_num == -1:
259
+ flag_num += 1
260
+ flag_num += 1
261
+ if api_data == "end":
262
+ flag_num -= 1
263
+ if flag_num == 0:
264
+ dispatcher.stop()
265
+ break
266
+ if not isinstance(api_data, ApiData):
267
+ continue
268
+ api_full_name = api_data.name
269
+ [_, api_name, _] = api_full_name.split(Const.SEP)
270
+ if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
271
+ continue
272
+ dispatcher.update_consume_queue(api_data)
273
+
274
+ if config.online_config.nfs_path:
275
+ shared_storage_communication_flow()
276
+ else:
277
+ tcp_communication_flow()
278
+
279
+
280
+ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
281
+ """
282
+ run api(api_name) if api_name not in black_list and in white_list.
283
+ If api is both in black_list and black_list, black_list first.
284
+ return: False for exec api, True for not exec
285
+ """
286
+ if black_list and api_name in black_list:
287
+ return True
288
+ if white_list and api_name not in white_list:
289
+ return True
290
+ return False
220
291
 
221
292
 
222
293
  def is_unsupported_api(api_name):
@@ -227,16 +298,16 @@ def is_unsupported_api(api_name):
227
298
  return flag
228
299
 
229
300
 
230
- def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
301
+ def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
231
302
  if not is_fwd_success or not is_bwd_success:
232
- processor = UtDataProcessor(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR))
303
+ processor = UtDataProcessor(error_data_path)
233
304
  for element in data_info.in_fwd_data_list:
234
305
  processor.save_tensors_in_element(api_full_name + '.forward.input', element)
235
- processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_out)
236
- processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_out)
306
+ processor.save_tensors_in_element(api_full_name + '.forward.output.bench', data_info.bench_output)
307
+ processor.save_tensors_in_element(api_full_name + '.forward.output.device', data_info.device_output)
237
308
  processor.save_tensors_in_element(api_full_name + '.backward.input', data_info.grad_in)
238
- processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad_out)
239
- processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad_out)
309
+ processor.save_tensors_in_element(api_full_name + '.backward.output.bench', data_info.bench_grad)
310
+ processor.save_tensors_in_element(api_full_name + '.backward.output.device', data_info.device_grad)
240
311
 
241
312
 
242
313
  def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
@@ -273,7 +344,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
273
344
 
274
345
  if need_backward:
275
346
  if need_to_backward(grad_index, out):
276
- backward_args = backward_content[api_full_name].get("grad_output")
347
+ backward_args = backward_content[api_full_name].get("input")
277
348
  grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
278
349
  bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
279
350
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
@@ -285,6 +356,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
285
356
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
286
357
 
287
358
 
359
+ def run_torch_api_online(api_full_name, api_data, backward_content):
360
+ in_fwd_data_list = []
361
+ [api_type, api_name, _] = api_full_name.split(Const.SEP)
362
+ args, kwargs, out = api_data.args, api_data.kwargs, api_data.result
363
+ in_fwd_data_list.append(args)
364
+ in_fwd_data_list.append(kwargs)
365
+ if kwargs.get("device"):
366
+ del kwargs["device"]
367
+
368
+ device_out = exec_api(api_type, api_name, args, kwargs)
369
+ device_out = move2device_exec(device_out, "cpu")
370
+ return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
371
+
372
+
288
373
  def get_api_info(api_info_dict, api_name, real_data_path):
289
374
  convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
290
375
  need_grad = True
@@ -314,45 +399,31 @@ def run_backward(args, grad, grad_index, out):
314
399
  return grad_out
315
400
 
316
401
 
317
def initialize_save_error_data(error_data_path):
    """Create and validate the directory used to store UT error data.

    Args:
        error_data_path: root directory for error data; it is created if needed
            and must be writable.

    Returns:
        The value returned by ``initialize_save_path`` for the checked root
        directory and the module-level ``UT_ERROR_DATA_DIR`` name.
    """
    check_path_before_create(error_data_path)
    create_directory(error_data_path)
    checked_root = FileChecker(error_data_path, FileCheckConst.DIR,
                               ability=FileCheckConst.WRITE_ABLE).common_check()
    # UT_ERROR_DATA_DIR is read at call time: run_ut_command may rebind the
    # global (with a timestamp suffix) before calling this function.
    return initialize_save_path(checked_root, UT_ERROR_DATA_DIR)
325
410
 
326
411
 
327
- def get_validated_result_csv_path(result_csv_path, mode):
328
- if mode not in ['result', 'detail']:
329
- raise ValueError("The csv mode must be result or detail")
330
- result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
331
- file_type=FileCheckConst.CSV_SUFFIX)
332
- validated_result_csv_path = result_csv_path_checker.common_check()
333
- if mode == 'result':
334
- result_csv_name = os.path.basename(validated_result_csv_path)
335
- pattern = r"^accuracy_checking_result_\d{14}\.csv$"
336
- if not re.match(pattern, result_csv_name):
337
- raise ValueError("When continue run ut, please do not modify the result csv name.")
338
- return validated_result_csv_path
339
-
340
-
341
- def get_validated_details_csv_path(validated_result_csv_path):
342
- result_csv_name = os.path.basename(validated_result_csv_path)
343
- details_csv_name = result_csv_name.replace('result', 'details')
344
- details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
345
- details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
346
- ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
347
- validated_details_csv_path = details_csv_path_checker.common_check()
348
- return validated_details_csv_path
412
def init_attl(config):
    """Build the ATTL endpoint for the benchmark side of online checking.

    Args:
        config: an OnlineConfig-like object providing ``host``, ``port``,
            ``nfs_path`` and ``tls_path``.

    Returns:
        An ``ATTL`` created for 'gpu' with ``is_benchmark_device=True``.
    """
    attl_config = ATTLConfig(
        is_benchmark_device=True,
        connect_ip=config.host,
        connect_port=config.port,
        nfs_path=config.nfs_path,
        tls_path=config.tls_path,
    )
    return ATTL('gpu', attl_config)
349
420
 
350
421
 
351
422
  def _run_ut_parser(parser):
352
423
  parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str,
353
- help="<Required> The api param tool result file: generate from api param tool, "
424
+ help="<Optional> The api param tool result file: generate from api param tool, "
354
425
  "a json file.",
355
- required=True)
426
+ required=False)
356
427
  parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str,
357
428
  help="<optional> The ut task result out path.",
358
429
  required=False)
@@ -378,12 +449,10 @@ def _run_ut_parser(parser):
378
449
  help="<optional> The path of accuracy_checking_result_{timestamp}.csv, "
379
450
  "when run ut is interrupted, enter the file path to continue run ut.",
380
451
  required=False)
381
- parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
382
- help="<optional> In real data mode, the root directory for storing real data "
383
- "must be configured.",
384
- required=False)
385
452
  parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
386
453
  help="<optional> Whether to filter the api in the api_info_file.", required=False)
454
+ parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
455
+ help="<optional> The path of config.json", required=False)
387
456
 
388
457
 
389
458
  def preprocess_forward_content(forward_content):
@@ -397,9 +466,9 @@ def preprocess_forward_content(forward_content):
397
466
  if key not in arg_cache:
398
467
  filtered_new_args = [
399
468
  {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
400
- for arg in value['args'] if isinstance(arg, dict)
469
+ for arg in value['input_args'] if isinstance(arg, dict)
401
470
  ]
402
- arg_cache[key] = (filtered_new_args, value['kwargs'])
471
+ arg_cache[key] = (filtered_new_args, value['input_kwargs'])
403
472
 
404
473
  filtered_new_args, new_kwargs = arg_cache[key]
405
474
 
@@ -444,50 +513,69 @@ def run_ut_command(args):
444
513
  except Exception as error:
445
514
  logger.error(f"Set device id failed. device id is: {args.device_id}")
446
515
  raise NotImplementedError from error
447
- check_link(args.api_info_file)
448
- api_info = os.path.realpath(args.api_info_file)
449
- check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
516
+
517
+ # 在线预检场景下,不需要外出输出api信息,forward_content, backward_content, real_data_path设置为None
518
+ # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
519
+ forward_content, backward_content, real_data_path = None, None, None
520
+ if args.api_info_file:
521
+ api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
522
+ ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
523
+ checked_api_info = api_info_file_checker.common_check()
524
+ forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
525
+ if args.filter_api:
526
+ logger.info("Start filtering the api in the forward_input_file.")
527
+ forward_content = preprocess_forward_content(forward_content)
528
+ logger.info("Finish filtering the api in the forward_input_file.")
529
+
450
530
  out_path = os.path.realpath(args.out_path) if args.out_path else "./"
451
531
  check_path_before_create(out_path)
452
532
  create_directory(out_path)
453
533
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
454
534
  out_path = out_path_checker.common_check()
455
535
  save_error_data = args.save_error_data
456
- forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info)
457
- if args.filter_api:
458
- logger.info("Start filtering the api in the forward_input_file.")
459
- forward_content = preprocess_forward_content(forward_content)
460
- logger.info("Finish filtering the api in the forward_input_file.")
461
536
 
462
537
  result_csv_path = os.path.join(out_path, RESULT_FILE_NAME)
463
538
  details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME)
464
539
  if args.result_csv_path:
465
540
  result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
466
541
  details_csv_path = get_validated_details_csv_path(result_csv_path)
542
+ white_list = msCheckerConfig.white_list
543
+ black_list = msCheckerConfig.black_list
544
+ error_data_path = msCheckerConfig.error_data_path
545
+ is_online = msCheckerConfig.is_online
546
+ nfs_path = msCheckerConfig.nfs_path
547
+ host = msCheckerConfig.host
548
+ port = msCheckerConfig.port
549
+ rank_list = msCheckerConfig.rank_list
550
+ tls_path = msCheckerConfig.tls_path
551
+ if args.config_path:
552
+ config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
553
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
554
+ checked_config_path = config_path_checker.common_check()
555
+ _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
556
+ white_list = task_config.white_list
557
+ black_list = task_config.black_list
558
+ error_data_path = task_config.error_data_path
559
+ is_online = task_config.is_online
560
+ nfs_path = task_config.nfs_path
561
+ host = task_config.host
562
+ port = task_config.port
563
+ rank_list = task_config.rank_list
564
+ tls_path = task_config.tls_path
565
+
467
566
  if save_error_data:
468
567
  if args.result_csv_path:
469
568
  time_info = result_csv_path.split('.')[0].split('_')[-1]
470
569
  global UT_ERROR_DATA_DIR
471
570
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
472
- initialize_save_error_data()
571
+ error_data_path = initialize_save_error_data(error_data_path)
572
+ online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
473
573
  run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
474
- args.result_csv_path, real_data_path)
574
+ args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
575
+ online_config)
475
576
  run_ut(run_ut_config)
476
577
 
477
578
 
478
- class UtDataInfo:
479
- def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
480
- backward_message, rank=0):
481
- self.bench_grad = bench_grad
482
- self.device_grad = device_grad
483
- self.device_output = device_output
484
- self.bench_output = bench_output
485
- self.grad_in = grad_in
486
- self.in_fwd_data_list = in_fwd_data_list
487
- self.backward_message = backward_message
488
- self.rank = rank
489
-
490
-
491
if __name__ == "__main__":
    # Script entry point: run the UT task, then report completion.
    _run_ut()
    logger.info("UT task completed.")
@@ -1,7 +1,74 @@
1
+ import os
2
+ import re
3
+
4
+ from msprobe.core.common.const import FileCheckConst
5
+ from msprobe.core.common.file_check import FileChecker
6
+ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
7
+ from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
8
+ from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
9
+ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
10
+ from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
11
+
1
# NOTE(review): presumably the APIs checked against an HF32 precision standard;
# confirm against the comparison code that reads this list.
hf_32_standard_api = ["conv1d", "conv2d"]


class Backward_Message:
    """Messages explaining why the backward pass of an API is not checked."""

    UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
    NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
    MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
19
+
20
+
21
class UtDataInfo:
    """Plain holder for the data produced by running one API under test.

    No validation is performed; every constructor argument is stored as-is
    on the attribute of the same name.
    """

    def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
                 backward_message, rank=0):
        # Backward results of the benchmark and device runs (None when no
        # backward pass was executed, as in the online path).
        self.bench_grad, self.device_grad = bench_grad, device_grad
        # Forward outputs of the device run and the benchmark run.
        self.device_output, self.bench_output = device_output, bench_output
        # NOTE(review): presumably the gradient fed into the backward pass; confirm.
        self.grad_in = grad_in
        # Recorded forward inputs (the online path stores [args, kwargs]).
        self.in_fwd_data_list = in_fwd_data_list
        # Presumably one of the Backward_Message strings, or None.
        self.backward_message = backward_message
        # Rank this data belongs to; defaults to 0.
        self.rank = rank
32
+
33
+
34
def get_validated_result_csv_path(result_csv_path, mode):
    """Validate a result/details csv path supplied to resume a UT run.

    Args:
        result_csv_path: path of an existing, read/writable ``.csv`` file.
        mode: 'result' or 'detail'. In 'result' mode the file name must also
            match ``accuracy_checking_result_<14 digits>.csv``.

    Returns:
        The path returned by ``FileChecker.common_check``.

    Raises:
        ValueError: for an unknown mode, or a renamed result csv in 'result' mode.
    """
    if mode not in ('result', 'detail'):
        raise ValueError("The csv mode must be result or detail")
    checked_path = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
                               file_type=FileCheckConst.CSV_SUFFIX).common_check()
    if mode != 'result':
        return checked_path
    csv_name = os.path.basename(checked_path)
    if re.match(r"^accuracy_checking_result_\d{14}\.csv$", csv_name) is None:
        raise ValueError("When continue run ut, please do not modify the result csv name.")
    return checked_path
46
+
47
+
48
def get_validated_details_csv_path(validated_result_csv_path):
    """Derive and validate the details csv that accompanies a result csv.

    The details file sits in the same directory, with every 'result' in the
    file name replaced by 'details'.

    Args:
        validated_result_csv_path: a result csv path already validated by
            ``get_validated_result_csv_path``.

    Returns:
        The details csv path returned by ``FileChecker.common_check``.
    """
    csv_dir, result_csv_name = os.path.split(validated_result_csv_path)
    details_csv_path = os.path.join(csv_dir, result_csv_name.replace('result', 'details'))
    return FileChecker(details_csv_path, FileCheckConst.FILE,
                       ability=FileCheckConst.READ_WRITE_ABLE,
                       file_type=FileCheckConst.CSV_SUFFIX).common_check()
56
+
57
+
58
def exec_api(api_type, api_name, args, kwargs):
    """Execute ``api_name`` through the hook-module template for ``api_type``.

    Args:
        api_type: one of "Functional", "Tensor", "Torch", "Aten" or "NPU"
            (case-sensitive).
        api_name: name of the API within that family.
        args: positional arguments forwarded to the template's ``forward``.
        kwargs: keyword arguments forwarded to the template's ``forward``.

    Returns:
        Whatever the selected template's ``forward`` returns.

    Raises:
        ValueError: if ``api_type`` is not a supported family. The previous
            chain of independent ``if`` statements left ``out`` unbound in this
            case and failed with an opaque UnboundLocalError.
    """
    # Functional/Tensor/Torch templates receive ``str`` as their second
    # argument while Aten/NPU receive ``None``; kept exactly as before.
    if api_type == "Functional":
        api = FunctionalOPTemplate(api_name, str, False)
    elif api_type == "Tensor":
        api = TensorOPTemplate(api_name, str, False)
    elif api_type == "Torch":
        api = TorchOPTemplate(api_name, str, False)
    elif api_type == "Aten":
        api = AtenOPTemplate(api_name, None, False)
    elif api_type == "NPU":
        api = NpuOPTemplate(api_name, None, False)
    else:
        raise ValueError(f"Unsupported api type: {api_type}")
    return api.forward(*args, **kwargs)