mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import subprocess
2
19
  import json
3
20
  import os
@@ -16,9 +33,10 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
16
33
  from msprobe.pytorch.common import parse_json_info_forward_backward
17
34
  from msprobe.pytorch.common.log import logger
18
35
  from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
19
- check_path_before_create, create_directory
36
+ create_directory, load_json, save_json
20
37
  from msprobe.core.common.file_utils import remove_path
21
- from msprobe.core.common.const import FileCheckConst
38
+ from msprobe.core.common.const import FileCheckConst, Const
39
+ from msprobe.core.common.utils import CompareException
22
40
 
23
41
 
24
42
  def split_json_file(input_file, num_splits, filter_api):
@@ -30,9 +48,11 @@ def split_json_file(input_file, num_splits, filter_api):
30
48
  for data_name in list(backward_data.keys()):
31
49
  backward_data[f"{data_name}.backward"] = backward_data.pop(data_name)
32
50
 
33
- with FileOpen(input_file, 'r') as file:
34
- input_data = json.load(file)
35
- input_data.pop("data")
51
+ input_data = load_json(input_file)
52
+ if input_data.get("data") is None:
53
+ logger.error("Invalid input file, 'data' field is missing")
54
+ raise CompareException("Invalid input file, 'data' field is missing")
55
+ input_data.pop("data")
36
56
 
37
57
  items = list(forward_data.items())
38
58
  total_items = len(items)
@@ -52,8 +72,7 @@ def split_json_file(input_file, num_splits, filter_api):
52
72
  }
53
73
  }
54
74
  split_filename = f"temp_part{i}.json"
55
- with FileOpen(split_filename, 'w') as split_file:
56
- json.dump(temp_data, split_file)
75
+ save_json(split_filename, temp_data)
57
76
  split_files.append(split_filename)
58
77
 
59
78
  return split_files, total_items
@@ -105,7 +124,7 @@ def run_parallel_ut(config):
105
124
  if output == '':
106
125
  break
107
126
  if '[ERROR]' in output:
108
- print(output, end='')
127
+ logger.warning(output)
109
128
  sys.stdout.flush()
110
129
  except ValueError as e:
111
130
  logger.warning(f"An error occurred while reading subprocess output: {e}")
@@ -119,7 +138,8 @@ def run_parallel_ut(config):
119
138
 
120
139
  for api_info in config.api_files:
121
140
  cmd = create_cmd(api_info, next(device_id_cycle))
122
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False)
141
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
142
+ text=True, bufsize=1, shell=False)
123
143
  processes.append(process)
124
144
  threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
125
145
 
@@ -150,7 +170,8 @@ def run_parallel_ut(config):
150
170
  logger.error(f"An unexpected error occurred: {e}")
151
171
  finally:
152
172
  if progress_bar.n < config.total_items:
153
- logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.")
173
+ logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
174
+ "the result CSV file will be utilized to resume the UT task.")
154
175
  clean_up()
155
176
  progress_bar_thread.join()
156
177
  try:
@@ -163,17 +184,21 @@ def run_parallel_ut(config):
163
184
 
164
185
 
165
186
  def prepare_config(args):
166
- check_link(args.api_info_file)
167
- api_info = os.path.realpath(args.api_info_file)
168
- check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX)
169
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
170
- check_path_before_create(out_path)
187
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
188
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
189
+ api_info = api_info_file_checker.common_check()
190
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
171
191
  create_directory(out_path)
172
192
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
173
193
  out_path = out_path_checker.common_check()
174
194
  split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
175
- config_path = os.path.realpath(args.config_path) if args.config_path else None
176
- result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
195
+ config_path = args.config_path if args.config_path else None
196
+ if config_path:
197
+ config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
198
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
199
+ config_path = config_path_checker.common_check()
200
+ result_csv_path = args.result_csv_path or os.path.join(
201
+ out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
177
202
  if not args.result_csv_path:
178
203
  details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
179
204
  comparator = Comparator(result_csv_path, details_csv_path, False)
@@ -190,7 +215,8 @@ def prepare_config(args):
190
215
  def main():
191
216
  parser = argparse.ArgumentParser(description='Run UT in parallel')
192
217
  _run_ut_parser(parser)
193
- parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64')
218
+ parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
219
+ help='Number of splits for parallel processing. Range: 1-64')
194
220
  args = parser.parse_args()
195
221
  config = prepare_config(args)
196
222
  run_parallel_ut(config)
@@ -1,3 +1,20 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import argparse
2
19
  import os
3
20
  import sys
@@ -11,11 +28,12 @@ else:
11
28
  import torch
12
29
  from tqdm import tqdm
13
30
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
14
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
15
- from msprobe.core.common.file_utils import check_link
31
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api
32
+ from msprobe.core.common.file_utils import check_link, FileChecker
33
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
34
+ from msprobe.core.common.const import FileCheckConst, Const
16
35
  from msprobe.pytorch.common.log import logger
17
36
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
18
- from msprobe.core.common.const import Const
19
37
 
20
38
 
21
39
  def check_tensor_overflow(x):
@@ -24,8 +42,8 @@ def check_tensor_overflow(x):
24
42
  tensor_max = x.cpu().detach().float().numpy().tolist()
25
43
  tensor_min = tensor_max
26
44
  else:
27
- tensor_max = torch._C._VariableFunctionsClass.max(x).cpu().detach().float().numpy().tolist()
28
- tensor_min = torch._C._VariableFunctionsClass.min(x).cpu().detach().float().numpy().tolist()
45
+ tensor_max = torch.max(x).cpu().detach().float().numpy().tolist()
46
+ tensor_min = torch.min(x).cpu().detach().float().numpy().tolist()
29
47
  # inf
30
48
  if tensor_max == float('inf') or tensor_min == float('-inf'):
31
49
  return True
@@ -57,23 +75,25 @@ def run_overflow_check(forward_file):
57
75
  logger.info("start UT test")
58
76
  forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
59
77
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
78
+ if is_unsupported_api(api_full_name, is_overflow_check=True):
79
+ continue
60
80
  try:
61
81
  run_torch_api(api_full_name, api_info_dict, real_data_path)
62
82
  except Exception as err:
63
83
  _, api_name, _ = api_full_name.split(Const.SEP)
64
84
  if "not implemented for 'Half'" in str(err):
65
- logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
66
- f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
85
+ logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
86
+ "check, so it will be skipped.")
67
87
  elif "expected scalar type Long" in str(err):
68
88
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
69
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
89
+ "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
70
90
  else:
71
91
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
72
92
 
73
93
 
74
94
  def run_torch_api(api_full_name, api_info_dict, real_data_path):
75
95
  torch.npu.clear_npu_overflow_flag()
76
- api_type, api_name, _ = api_full_name.split(Const.SEP)
96
+ api_type, api_name = extract_basic_api_segments(api_full_name)
77
97
  args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
78
98
  if not need_grad:
79
99
  logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
@@ -118,8 +138,9 @@ def _run_overflow_check(parser=None):
118
138
  def _run_overflow_check_command(args):
119
139
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
120
140
  npu_device = "npu:" + str(args.device_id)
121
- check_link(args.api_info_file)
122
- api_info = os.path.realpath(args.api_info_file)
141
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
142
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
143
+ api_info = api_info_file_checker.common_check()
123
144
  try:
124
145
  torch.npu.set_device(npu_device)
125
146
  except Exception as error:
@@ -1,6 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import argparse
2
19
  import os
3
- import csv
20
+ import re
4
21
  import sys
5
22
  import time
6
23
  import gc
@@ -17,43 +34,34 @@ else:
17
34
  import torch
18
35
  from tqdm import tqdm
19
36
 
20
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api, UtDataInfo, \
21
- get_validated_result_csv_path, get_validated_details_csv_path, exec_api
37
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
38
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
22
39
  from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
23
40
  from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
24
41
  initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
25
42
  from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
26
43
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
27
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
44
+ from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
28
45
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
29
- from msprobe.core.common.file_utils import FileOpen, FileChecker, \
30
- change_mode, check_path_before_create, create_directory, get_json_contents
46
+ from msprobe.core.common.file_utils import FileChecker, change_mode, \
47
+ create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
31
48
  from msprobe.pytorch.common.log import logger
32
49
  from msprobe.pytorch.pt_config import parse_json_config
33
50
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
51
+ from msprobe.core.common.utils import safe_get_value
34
52
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
35
53
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
54
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
36
55
 
37
56
 
38
57
  current_time = time.strftime("%Y%m%d%H%M%S")
39
58
  UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
40
59
  RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
41
60
  DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
42
- RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
43
- 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
44
- 'black_list', 'error_data_path', 'online_config'])
45
61
 
46
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
47
62
 
48
63
  not_backward_list = ['repeat_interleave']
49
- not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
50
- not_raise_dtype_set = {'type_as'}
51
64
 
52
- RAISE_PRECISION = {
53
- torch.float16: torch.float32,
54
- torch.bfloat16: torch.float32,
55
- torch.float32: torch.float64
56
- }
57
65
 
58
66
  tqdm_params = {
59
67
  'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
@@ -71,98 +79,6 @@ tqdm_params = {
71
79
  }
72
80
 
73
81
 
74
- def deal_detach(arg, to_detach=True):
75
- return arg.detach() if to_detach else arg
76
-
77
-
78
- def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
79
- '''
80
- 将标杆数据的dtype转换为raise_dtype
81
- 输入:
82
- api_name:api名称
83
- arg:标杆输入
84
- raise_dtype:需要转换的dtype
85
- 输出:
86
- arg: 转换dtype的标杆输入
87
- '''
88
- if api_name in hf_32_standard_api and arg.dtype == torch.float32:
89
- return arg
90
- if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
91
- return arg
92
- return arg.type(raise_dtype)
93
-
94
-
95
- def generate_device_params(input_args, input_kwargs, need_backward, api_name):
96
- def recursive_arg_to_device(arg_in, to_detach):
97
- if isinstance(arg_in, (list, tuple)):
98
- return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
99
- elif isinstance(arg_in, torch.Tensor):
100
- if need_backward and arg_in.requires_grad:
101
- arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
102
- temp_arg_in = arg_in * 1
103
- arg_in = temp_arg_in.type_as(arg_in)
104
- arg_in.retain_grad()
105
- return arg_in
106
- else:
107
- return deal_detach(arg_in.clone(), to_detach).to(current_device)
108
- else:
109
- return arg_in
110
-
111
- is_detach = api_name not in not_detach_set
112
- device_args = recursive_arg_to_device(input_args, is_detach)
113
- device_kwargs = \
114
- {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
115
- return device_args, device_kwargs
116
-
117
-
118
- def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
119
- def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
120
- if isinstance(arg_in, (list, tuple)):
121
- return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
122
- elif isinstance(arg_in, torch.Tensor):
123
- if need_backward and arg_in.requires_grad:
124
- arg_in = deal_detach(raise_bench_data_dtype(
125
- api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
126
- temp_arg_in = arg_in * 1
127
- arg_in = temp_arg_in.type_as(arg_in)
128
- arg_in.retain_grad()
129
- return arg_in
130
- else:
131
- return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
132
- else:
133
- return arg_in
134
-
135
- def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
136
- if arg_in.dtype in RAISE_PRECISION:
137
- return True
138
- if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
139
- return True
140
- return False
141
-
142
- def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False):
143
- if isinstance(arg_in, (list, tuple)):
144
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs) for arg in arg_in))
145
- elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
146
- return set([arg_in.dtype])
147
- elif isinstance(arg_in, dict) and check_kwargs:
148
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True) for v in arg_in.values()))
149
- return set()
150
-
151
- raise_dtype = None
152
- need_raise_dtypes = recursive_find_dtypes(input_args)
153
- need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
154
- if len(need_raise_dtypes) == 1:
155
- raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
156
- elif len(need_raise_dtypes) >= 2:
157
- raise_dtype = torch.float32
158
-
159
- raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
160
- is_detach = api_name not in not_detach_set
161
- cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
162
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
163
- return cpu_args, cpu_kwargs
164
-
165
-
166
82
  def run_ut(config):
167
83
  logger.info("start UT test")
168
84
  if config.online_config.is_online:
@@ -179,10 +95,12 @@ def run_ut(config):
179
95
  if config.online_config.is_online:
180
96
  run_api_online(config, compare)
181
97
  else:
182
- with FileOpen(config.result_csv_path, 'r') as file:
183
- csv_reader = csv.reader(file)
184
- next(csv_reader)
185
- api_name_set = {row[0] for row in csv_reader}
98
+ csv_df = read_csv(config.result_csv_path)
99
+ try:
100
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
101
+ except IndexError:
102
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
103
+ api_name_set = set()
186
104
  run_api_offline(config, compare, api_name_set)
187
105
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
188
106
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -198,17 +116,23 @@ def run_api_offline(config, compare, api_name_set):
198
116
  if api_full_name in api_name_set:
199
117
  continue
200
118
  if is_unsupported_api(api_full_name):
119
+ skip_message = f"API {api_full_name} not support for run ut. SKIP."
120
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
121
+ record_skip_info(api_full_name, compare, compare_alg_results)
201
122
  continue
202
123
  _, api_name = extract_basic_api_segments(api_full_name)
203
124
  if not api_name:
204
125
  err_message = f"API {api_full_name} not support for run ut. SKIP."
205
126
  logger.error(err_message)
206
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
207
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
208
- compare.record_results(result_info)
127
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, err_message)
128
+ record_skip_info(api_full_name, compare, compare_alg_results)
209
129
  continue
210
130
  try:
211
131
  if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list):
132
+ skip_message = f"API {api_name} in black list or not in white list. SKIP."
133
+ logger.info(skip_message)
134
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, skip_message)
135
+ record_skip_info(api_full_name, compare, compare_alg_results)
212
136
  continue
213
137
  data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
214
138
  is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info)
@@ -217,12 +141,11 @@ def run_api_offline(config, compare, api_name_set):
217
141
  except Exception as err:
218
142
  if "expected scalar type Long" in str(err):
219
143
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
220
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
144
+ "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
221
145
  else:
222
146
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
223
- fwd_compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
224
- result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [fwd_compare_alg_results], None, 0)
225
- compare.record_results(result_info)
147
+ compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
148
+ record_skip_info(api_full_name, compare, compare_alg_results)
226
149
  finally:
227
150
  if is_gpu:
228
151
  torch.cuda.empty_cache()
@@ -298,14 +221,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
298
221
  return False
299
222
 
300
223
 
301
- def is_unsupported_api(api_name):
302
- split_name = api_name.split(Const.SEP)[0]
303
- flag = split_name == Const.DISTRIBUTED
304
- if flag:
305
- logger.info(f"{split_name} api is not supported for run ut. SKIP.")
306
- return flag
307
-
308
-
309
224
  def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
310
225
  if not is_fwd_success or not is_bwd_success:
311
226
  processor = UtDataProcessor(error_data_path)
@@ -327,12 +242,12 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
327
242
  in_fwd_data_list.append(kwargs)
328
243
  need_backward = api_full_name in backward_content
329
244
  if not need_grad:
330
- logger.warning("%s %s" % (api_full_name, Backward_Message.UNSUPPORT_BACKWARD_MESSAGE))
331
- backward_message += Backward_Message.UNSUPPORT_BACKWARD_MESSAGE
245
+ logger.warning("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE))
246
+ backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
332
247
  if api_name in not_backward_list:
333
248
  need_grad = False
334
- logger.warning("%s %s" % (api_full_name, Backward_Message.NO_BACKWARD_RESULT_MESSAGE))
335
- backward_message += Backward_Message.NO_BACKWARD_RESULT_MESSAGE
249
+ logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
250
+ backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
336
251
  need_backward = need_backward and need_grad
337
252
  if kwargs.get("device"):
338
253
  del kwargs["device"]
@@ -353,16 +268,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
353
268
  if need_backward:
354
269
  if need_to_backward(grad_index, out):
355
270
  backward_args = backward_content[api_full_name].get("input")
356
- grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
271
+ func_options = {
272
+ 'real_data_path': real_data_path
273
+ }
274
+ grad = gen_args(backward_args, api_name, func_options)
275
+ grad = safe_get_value(grad, 0, "grad")
357
276
  bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
358
277
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
359
278
  device_grad = grad.clone().detach().to(current_device)
360
279
  device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
361
280
  else:
362
- backward_message += Backward_Message.MULTIPLE_BACKWARD_MESSAGE
281
+ backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
363
282
  if api_name == "npu_fusion_attention":
364
- out = out[0]
365
- device_out = device_out[0]
283
+ out = safe_get_value(out, 0, "out")
284
+ device_out = safe_get_value(device_out, 0, "device_out")
366
285
 
367
286
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
368
287
 
@@ -398,6 +317,9 @@ def need_to_backward(grad_index, out):
398
317
 
399
318
  def run_backward(args, grad, grad_index, out):
400
319
  if grad_index is not None:
320
+ if grad_index >= len(out):
321
+ logger.error(f"Run backward error when grad_index is {grad_index}")
322
+ raise IndexError(f"Run backward error when grad_index is {grad_index}")
401
323
  out[grad_index].backward(grad)
402
324
  else:
403
325
  out.backward(grad)
@@ -411,12 +333,11 @@ def run_backward(args, grad, grad_index, out):
411
333
 
412
334
 
413
335
  def initialize_save_error_data(error_data_path):
414
- check_path_before_create(error_data_path)
415
336
  create_directory(error_data_path)
416
337
  error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
417
338
  ability=FileCheckConst.WRITE_ABLE)
418
339
  error_data_path = error_data_path_checker.common_check()
419
- error_data_path =initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
340
+ error_data_path = initialize_save_path(error_data_path, UT_ERROR_DATA_DIR)
420
341
  return error_data_path
421
342
 
422
343
 
@@ -477,7 +398,8 @@ def preprocess_forward_content(forward_content):
477
398
  if key not in arg_cache:
478
399
  filtered_new_args = [
479
400
  {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
480
- for arg in value['input_args'] if isinstance(arg, dict)
401
+ for arg in value['input_args']
402
+ if isinstance(arg, dict)
481
403
  ]
482
404
  arg_cache[key] = (filtered_new_args, value['input_kwargs'])
483
405
 
@@ -512,7 +434,49 @@ def _run_ut(parser=None):
512
434
  run_ut_command(args)
513
435
 
514
436
 
437
+ def checked_online_config(online_config):
438
+ if not online_config.is_online:
439
+ return
440
+ if not isinstance(online_config.is_online, bool):
441
+ raise ValueError("is_online must be bool type")
442
+ # rank_list
443
+ if not isinstance(online_config.rank_list, list):
444
+ raise ValueError("rank_list must be a list")
445
+ if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
446
+ raise ValueError("All elements in rank_list must be integers")
447
+
448
+ # nfs_path
449
+ if online_config.nfs_path:
450
+ check_file_or_directory_path(online_config.nfs_path, isdir=True)
451
+ return
452
+ # tls_path
453
+ if online_config.tls_path:
454
+ check_file_or_directory_path(online_config.tls_path, isdir=True)
455
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
456
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
457
+ check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
458
+
459
+ # host and port
460
+ if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
461
+ raise Exception(f"host: {online_config.host} is invalid.")
462
+ if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
463
+ raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
464
+
465
+
515
466
  def run_ut_command(args):
467
+ if args.config_path:
468
+ config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
469
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
470
+ checked_config_path = config_path_checker.common_check()
471
+ _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
472
+ checker_config = CheckerConfig(task_config)
473
+ else:
474
+ checker_config = CheckerConfig()
475
+
476
+ if not checker_config.is_online and not args.api_info_file:
477
+ logger.error("Please provide api_info_file for offline run ut.")
478
+ raise Exception("Please provide api_info_file for offline run ut.")
479
+
516
480
  if not is_gpu:
517
481
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
518
482
  used_device = current_device + ":" + str(args.device_id[0])
@@ -529,17 +493,16 @@ def run_ut_command(args):
529
493
  # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析
530
494
  forward_content, backward_content, real_data_path = None, None, None
531
495
  if args.api_info_file:
532
- api_info_file_checker = FileChecker(file_path = args.api_info_file, path_type = FileCheckConst.FILE,
533
- ability = FileCheckConst.READ_ABLE, file_type = FileCheckConst.JSON_SUFFIX)
496
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
497
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
534
498
  checked_api_info = api_info_file_checker.common_check()
535
499
  forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
536
500
  if args.filter_api:
537
- logger.info("Start filtering the api in the forward_input_file.")
501
+ logger.info("Start filtering the api in the api_info_file.")
538
502
  forward_content = preprocess_forward_content(forward_content)
539
- logger.info("Finish filtering the api in the forward_input_file.")
503
+ logger.info("Finish filtering the api in the api_info_file.")
540
504
 
541
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
542
- check_path_before_create(out_path)
505
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
543
506
  create_directory(out_path)
544
507
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
545
508
  out_path = out_path_checker.common_check()
@@ -550,40 +513,27 @@ def run_ut_command(args):
550
513
  if args.result_csv_path:
551
514
  result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
552
515
  details_csv_path = get_validated_details_csv_path(result_csv_path)
553
- white_list = msCheckerConfig.white_list
554
- black_list = msCheckerConfig.black_list
555
- error_data_path = msCheckerConfig.error_data_path
556
- is_online = msCheckerConfig.is_online
557
- nfs_path = msCheckerConfig.nfs_path
558
- host = msCheckerConfig.host
559
- port = msCheckerConfig.port
560
- rank_list = msCheckerConfig.rank_list
561
- tls_path = msCheckerConfig.tls_path
562
- if args.config_path:
563
- config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
564
- FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
565
- checked_config_path = config_path_checker.common_check()
566
- _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
567
- white_list = task_config.white_list
568
- black_list = task_config.black_list
569
- error_data_path = task_config.error_data_path
570
- is_online = task_config.is_online
571
- nfs_path = task_config.nfs_path
572
- host = task_config.host
573
- port = task_config.port
574
- rank_list = task_config.rank_list
575
- tls_path = task_config.tls_path
576
516
 
517
+ error_data_path = checker_config.error_data_path
577
518
  if save_error_data:
578
519
  if args.result_csv_path:
579
520
  time_info = result_csv_path.split('.')[0].split('_')[-1]
580
521
  global UT_ERROR_DATA_DIR
581
522
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
582
523
  error_data_path = initialize_save_error_data(error_data_path)
583
- online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
584
- run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
585
- args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
586
- online_config)
524
+ online_config = checker_config.get_online_config()
525
+ checked_online_config(online_config)
526
+ config_params = {
527
+ 'forward_content': forward_content,
528
+ 'backward_content': backward_content,
529
+ 'result_csv_path': result_csv_path,
530
+ 'details_csv_path': details_csv_path,
531
+ 'save_error_data': save_error_data,
532
+ 'is_continue_run_ut': args.result_csv_path,
533
+ 'real_data_path': real_data_path,
534
+ 'error_data_path': error_data_path
535
+ }
536
+ run_ut_config = checker_config.get_run_ut_config(**config_params)
587
537
  run_ut(run_ut_config)
588
538
 
589
539