mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -23,16 +23,19 @@ try:
23
23
  import torch_npu
24
24
  except ImportError:
25
25
  is_gpu = True
26
+ current_device = "cuda"
26
27
  else:
27
28
  is_gpu = False
29
+ current_device = "npu"
28
30
  import torch
29
31
  from tqdm import tqdm
30
32
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
31
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
32
- from msprobe.core.common.file_utils import check_link
33
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api, is_unsupported_api, ExecParams
34
+ from msprobe.core.common.file_utils import check_link, FileChecker
35
+ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
36
+ from msprobe.core.common.const import FileCheckConst, Const
33
37
  from msprobe.pytorch.common.log import logger
34
38
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
35
- from msprobe.core.common.const import Const
36
39
 
37
40
 
38
41
  def check_tensor_overflow(x):
@@ -60,52 +63,80 @@ def check_tensor_overflow(x):
60
63
  return False
61
64
 
62
65
 
63
- def check_data_overflow(x):
64
- if isinstance(x, (tuple, list)) and x:
65
- for _, item in enumerate(x):
66
- if check_data_overflow(item):
67
- return True
68
- return False
66
+ def check_data_overflow(x, device):
67
+ if isinstance(x, (tuple, list)):
68
+ if not x:
69
+ return False
70
+ return any(check_data_overflow(item, device) for item in x)
69
71
  else:
70
- return check_tensor_overflow(x)
72
+ if device == Const.CPU_LOWERCASE:
73
+ return check_tensor_overflow(x)
74
+ else:
75
+ return torch_npu.npu.utils.npu_check_overflow(x)
76
+
77
+
78
+ def is_bool_output(x):
79
+ if isinstance(x, (tuple, list)):
80
+ if not x:
81
+ return False
82
+ return any(is_bool_output(item) for item in x)
83
+ else:
84
+ return isinstance(x, bool)
71
85
 
72
86
 
73
87
  def run_overflow_check(forward_file):
74
88
  logger.info("start UT test")
75
89
  forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
90
+ if real_data_path:
91
+ dump_path = os.path.dirname(forward_file)
92
+ real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
76
93
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
94
+ if is_unsupported_api(api_full_name, is_overflow_check=True):
95
+ continue
77
96
  try:
78
97
  run_torch_api(api_full_name, api_info_dict, real_data_path)
79
98
  except Exception as err:
80
99
  _, api_name, _ = api_full_name.split(Const.SEP)
81
100
  if "not implemented for 'Half'" in str(err):
82
- logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
83
- f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
101
+ logger.warning(f"API {api_name} not support half tensor in CPU. This API does not support overflow "
102
+ "check, so it will be skipped.")
84
103
  elif "expected scalar type Long" in str(err):
85
104
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
86
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
105
+ "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
106
+ elif "could not create a primitive descriptor for a matmul primitive" in str(err):
107
+ logger.warning(f"API {api_name} not support matmul primitive in CPU due to pytorch bug, "
108
+ "so it will be skipped.")
87
109
  else:
88
110
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
89
111
 
90
112
 
91
113
  def run_torch_api(api_full_name, api_info_dict, real_data_path):
92
114
  torch.npu.clear_npu_overflow_flag()
93
- api_type, api_name, _ = api_full_name.split(Const.SEP)
115
+ api_type, api_name = extract_basic_api_segments(api_full_name)
94
116
  args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
95
117
  if not need_grad:
96
118
  logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
97
119
  % api_full_name)
120
+ device_info_kwargs = kwargs.get(Const.DEVICE)
121
+ if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
122
+ kwargs[Const.DEVICE] = current_device
98
123
  npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
99
- if kwargs.get("device"):
100
- del kwargs["device"]
101
- out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs)
102
- npu_out = exec_api(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs)
124
+ if kwargs.get(Const.DEVICE):
125
+ del kwargs[Const.DEVICE]
126
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, args, kwargs, False, None)
127
+ device_exec_params = ExecParams(api_type, api_name, Const.NPU_LOWERCASE, npu_args, npu_kwargs, False, None)
128
+ out = exec_api(cpu_exec_params)
129
+ npu_out = exec_api(device_exec_params)
103
130
  if out is None and npu_out is None:
104
131
  logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
105
132
  return
133
+ if is_bool_output(out) or is_bool_output(npu_out):
134
+ logger.warning("The output of %s is bool type.This dtype not support overflow, so it will be skipped."
135
+ % api_full_name)
136
+ return
106
137
 
107
- cpu_overflow = check_data_overflow(out)
108
- npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
138
+ cpu_overflow = check_data_overflow(out, Const.CPU_LOWERCASE)
139
+ npu_overflow = check_data_overflow(npu_out, Const.NPU_LOWERCASE)
109
140
  if cpu_overflow == npu_overflow:
110
141
  logger.warning("The %s overflow is a normal overflow." % api_full_name)
111
142
  else:
@@ -135,8 +166,9 @@ def _run_overflow_check(parser=None):
135
166
  def _run_overflow_check_command(args):
136
167
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
137
168
  npu_device = "npu:" + str(args.device_id)
138
- check_link(args.api_info_file)
139
- api_info = os.path.realpath(args.api_info_file)
169
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
170
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
171
+ api_info = api_info_file_checker.common_check()
140
172
  try:
141
173
  torch.npu.set_device(npu_device)
142
174
  except Exception as error:
@@ -17,7 +17,7 @@
17
17
 
18
18
  import argparse
19
19
  import os
20
- import csv
20
+ import re
21
21
  import sys
22
22
  import time
23
23
  import gc
@@ -31,39 +31,40 @@ except ImportError:
31
31
  else:
32
32
  is_gpu = False
33
33
  current_device = "npu"
34
+
34
35
  import torch
35
36
  from tqdm import tqdm
36
37
 
37
38
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import BackwardMessage, UtDataInfo, \
38
- get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info
39
+ get_validated_result_csv_path, get_validated_details_csv_path, exec_api, record_skip_info, is_unsupported_api
39
40
  from msprobe.pytorch.api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
40
41
  from msprobe.pytorch.api_accuracy_checker.common.utils import api_info_preprocess, \
41
42
  initialize_save_path, UtDataProcessor, extract_basic_api_segments, ApiData
42
43
  from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
43
44
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
44
- from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
45
+ from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
45
46
  from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
46
- from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, \
47
- create_directory, get_json_contents, read_csv
47
+ from msprobe.core.common.file_utils import FileChecker, change_mode, \
48
+ create_directory, get_json_contents, read_csv, check_file_or_directory_path, check_crt_valid
48
49
  from msprobe.pytorch.common.log import logger
49
50
  from msprobe.pytorch.pt_config import parse_json_config
50
51
  from msprobe.core.common.const import Const, FileCheckConst, CompareConst
52
+ from msprobe.core.common.utils import safe_get_value, CompareException
53
+ from msprobe.pytorch.common.utils import seed_all
51
54
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, move2device_exec
52
55
  from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher
53
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params
56
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_cpu_params, generate_device_params, \
57
+ ExecParams
54
58
 
55
59
 
56
60
  current_time = time.strftime("%Y%m%d%H%M%S")
57
61
  UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
58
62
  RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
59
63
  DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
60
- RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
61
- 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
62
- 'black_list', 'error_data_path', 'online_config'])
63
64
 
64
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
65
65
 
66
66
  not_backward_list = ['repeat_interleave']
67
+ unsupported_backward_list = ['masked_select']
67
68
 
68
69
 
69
70
  tqdm_params = {
@@ -99,7 +100,11 @@ def run_ut(config):
99
100
  run_api_online(config, compare)
100
101
  else:
101
102
  csv_df = read_csv(config.result_csv_path)
102
- api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
103
+ try:
104
+ api_name_set = {row[0] for row in csv_df.itertuples(index=False, name=None)}
105
+ except IndexError:
106
+ logger.error(f"Read {config.result_csv_path} error, api_name_set is empty.")
107
+ api_name_set = set()
103
108
  run_api_offline(config, compare, api_name_set)
104
109
  for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list):
105
110
  change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -140,7 +145,7 @@ def run_api_offline(config, compare, api_name_set):
140
145
  except Exception as err:
141
146
  if "expected scalar type Long" in str(err):
142
147
  logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
143
- f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
148
+ "'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.")
144
149
  else:
145
150
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
146
151
  compare_alg_results = err_column.to_column_value(CompareConst.SKIP, str(err))
@@ -220,14 +225,6 @@ def blacklist_and_whitelist_filter(api_name, black_list, white_list):
220
225
  return False
221
226
 
222
227
 
223
- def is_unsupported_api(api_name):
224
- split_name = api_name.split(Const.SEP)[0]
225
- flag = split_name == Const.DISTRIBUTED
226
- if flag:
227
- logger.info(f"{split_name} api is not supported for run ut. SKIP.")
228
- return flag
229
-
230
-
231
228
  def do_save_error_data(api_full_name, data_info, error_data_path, is_fwd_success, is_bwd_success):
232
229
  if not is_fwd_success or not is_bwd_success:
233
230
  processor = UtDataProcessor(error_data_path)
@@ -244,7 +241,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
244
241
  in_fwd_data_list = []
245
242
  backward_message = ''
246
243
  api_type, api_name = extract_basic_api_segments(api_full_name)
247
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
244
+ args, kwargs, output_dtype = get_api_info(api_info_dict, api_name, real_data_path)
245
+ need_grad = check_need_grad(api_info_dict)
248
246
  in_fwd_data_list.append(args)
249
247
  in_fwd_data_list.append(kwargs)
250
248
  need_backward = api_full_name in backward_content
@@ -253,16 +251,32 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
253
251
  backward_message += BackwardMessage.UNSUPPORT_BACKWARD_MESSAGE
254
252
  if api_name in not_backward_list:
255
253
  need_grad = False
256
- logger.warning("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
254
+ logger.info("%s %s" % (api_full_name, BackwardMessage.NO_BACKWARD_RESULT_MESSAGE))
257
255
  backward_message += BackwardMessage.NO_BACKWARD_RESULT_MESSAGE
256
+ if api_name in unsupported_backward_list:
257
+ need_grad = False
258
+ logger.info("%s %s" % (api_full_name, BackwardMessage.UNSUPPORT_API_MESSAGE))
259
+ backward_message += BackwardMessage.UNSUPPORT_API_MESSAGE
258
260
  need_backward = need_backward and need_grad
259
- if kwargs.get("device"):
260
- del kwargs["device"]
261
- cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
261
+
262
+ device_info_kwargs = kwargs.get(Const.DEVICE)
263
+ if device_info_kwargs and device_info_kwargs.get(Const.VALUE):
264
+ kwargs[Const.DEVICE] = current_device
262
265
  device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
266
+ if kwargs.get(Const.DEVICE):
267
+ del kwargs[Const.DEVICE]
268
+ cpu_params = generate_cpu_params(args, kwargs, need_backward, api_name)
269
+ cpu_args, cpu_kwargs = cpu_params.cpu_args, cpu_params.cpu_kwargs
270
+ autocast_dtype, is_autocast = cpu_params.autocast_dtype, cpu_params.is_autocast
271
+ if not is_autocast and output_dtype:
272
+ is_autocast = autocast_dtype != output_dtype
273
+ autocast_dtype = output_dtype
263
274
  bench_grad_out, device_grad_out = None, None
264
- out = exec_api(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs)
265
- device_out = exec_api(api_type, api_name, current_device, device_args, device_kwargs)
275
+ cpu_exec_params = ExecParams(api_type, api_name, Const.CPU_LOWERCASE, cpu_args, cpu_kwargs, False, autocast_dtype)
276
+ out = exec_api(cpu_exec_params)
277
+ device_exec_params = ExecParams(api_type, api_name, current_device, device_args, device_kwargs, is_autocast,
278
+ autocast_dtype)
279
+ device_out = exec_api(device_exec_params)
266
280
  current_path = os.path.dirname(os.path.realpath(__file__))
267
281
  ut_setting_path = os.path.join(current_path, "torch_ut_setting.json")
268
282
  api_setting_dict = get_json_contents(ut_setting_path)
@@ -278,16 +292,18 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
278
292
  func_options = {
279
293
  'real_data_path': real_data_path
280
294
  }
281
- grad = gen_args(backward_args, api_name, func_options)[0]
282
- bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
295
+ grad = gen_args(backward_args, api_name, func_options)
296
+ grad = safe_get_value(grad, 0, "grad")
297
+ grad_params = generate_cpu_params(grad, {}, False, api_name)
298
+ bench_grad = grad_params.cpu_args
283
299
  bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
284
300
  device_grad = grad.clone().detach().to(current_device)
285
301
  device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
286
302
  else:
287
303
  backward_message += BackwardMessage.MULTIPLE_BACKWARD_MESSAGE
288
304
  if api_name == "npu_fusion_attention":
289
- out = out[0]
290
- device_out = device_out[0]
305
+ out = safe_get_value(out, 0, "out")
306
+ device_out = safe_get_value(device_out, 0, "device_out")
291
307
 
292
308
  return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message)
293
309
 
@@ -306,13 +322,18 @@ def run_torch_api_online(api_full_name, api_data, backward_content):
306
322
  return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank)
307
323
 
308
324
 
309
- def get_api_info(api_info_dict, api_name, real_data_path):
310
- convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
325
+ def check_need_grad(api_info_dict):
311
326
  need_grad = True
312
- if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"):
327
+ if api_info_dict.get(Const.INPUT_KWARGS) and "out" in api_info_dict.get(Const.INPUT_KWARGS):
313
328
  need_grad = False
314
- args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
315
- return args, kwargs, need_grad
329
+ return need_grad
330
+
331
+
332
+ def get_api_info(api_info_dict, api_name, real_data_path):
333
+ convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)
334
+ need_grad = check_need_grad(api_info_dict)
335
+ args, kwargs, output_dtype = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
336
+ return args, kwargs, output_dtype
316
337
 
317
338
 
318
339
  def need_to_backward(grad_index, out):
@@ -323,20 +344,32 @@ def need_to_backward(grad_index, out):
323
344
 
324
345
  def run_backward(args, grad, grad_index, out):
325
346
  if grad_index is not None:
347
+ if grad_index >= len(out):
348
+ logger.error(f"Run backward error when grad_index is {grad_index}")
349
+ raise IndexError(f"Run backward error when grad_index is {grad_index}")
326
350
  out[grad_index].backward(grad)
327
351
  else:
328
352
  out.backward(grad)
329
- args_grad = []
330
- for arg in args:
331
- if isinstance(arg, torch.Tensor):
332
- args_grad.append(arg.grad)
333
- grad_out = args_grad
353
+
354
+ grad_out = extract_tensors_grad(args)
334
355
 
335
356
  return grad_out
336
357
 
337
358
 
359
+ def extract_tensors_grad(args, depth=0):
360
+ if depth > Const.MAX_DEPTH:
361
+ logger.error("The depth of arg_in is too large, please check the arg_in.")
362
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
363
+ grads = []
364
+ for arg in args:
365
+ if isinstance(arg, torch.Tensor):
366
+ grads.append(arg.grad)
367
+ elif isinstance(arg, (list, tuple)):
368
+ grads.extend(extract_tensors_grad(arg, depth+1))
369
+ return grads
370
+
371
+
338
372
  def initialize_save_error_data(error_data_path):
339
- check_path_before_create(error_data_path)
340
373
  create_directory(error_data_path)
341
374
  error_data_path_checker = FileChecker(error_data_path, FileCheckConst.DIR,
342
375
  ability=FileCheckConst.WRITE_ABLE)
@@ -438,9 +471,55 @@ def _run_ut(parser=None):
438
471
  run_ut_command(args)
439
472
 
440
473
 
474
+ def checked_online_config(online_config):
475
+ if not online_config.is_online:
476
+ return
477
+ if not isinstance(online_config.is_online, bool):
478
+ raise ValueError("is_online must be bool type")
479
+ # rank_list
480
+ if not isinstance(online_config.rank_list, list):
481
+ raise ValueError("rank_list must be a list")
482
+ if online_config.rank_list and not all(isinstance(rank, int) for rank in online_config.rank_list):
483
+ raise ValueError("All elements in rank_list must be integers")
484
+
485
+ # nfs_path
486
+ if online_config.nfs_path:
487
+ check_file_or_directory_path(online_config.nfs_path, isdir=True)
488
+ return
489
+ # tls_path
490
+ if online_config.tls_path:
491
+ check_file_or_directory_path(online_config.tls_path, isdir=True)
492
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.key"))
493
+ check_file_or_directory_path(os.path.join(online_config.tls_path, "server.crt"))
494
+ check_crt_valid(os.path.join(online_config.tls_path, "server.crt"))
495
+
496
+ # host and port
497
+ if not isinstance(online_config.host, str) or not re.match(Const.ipv4_pattern, online_config.host):
498
+ raise Exception(f"host: {online_config.host} is invalid.")
499
+ if not isinstance(online_config.port, int) or not (0 < online_config.port <= 65535):
500
+ raise Exception(f"port: {online_config.port} is invalid, port range 0-65535.")
501
+
502
+
441
503
  def run_ut_command(args):
504
+ if args.config_path:
505
+ config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
506
+ FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
507
+ checked_config_path = config_path_checker.common_check()
508
+ _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
509
+ checker_config = CheckerConfig(task_config)
510
+ else:
511
+ checker_config = CheckerConfig()
512
+
513
+ if not checker_config.is_online and not args.api_info_file:
514
+ logger.error("Please provide api_info_file for offline run ut.")
515
+ raise Exception("Please provide api_info_file for offline run ut.")
516
+
442
517
  if not is_gpu:
443
518
  torch.npu.set_compile_mode(jit_compile=args.jit_compile)
519
+ if args.jit_compile:
520
+ torch.npu.config.allow_internal_format = True
521
+ else:
522
+ torch.npu.config.allow_internal_format = False
444
523
  used_device = current_device + ":" + str(args.device_id[0])
445
524
  try:
446
525
  if is_gpu:
@@ -459,13 +538,15 @@ def run_ut_command(args):
459
538
  ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
460
539
  checked_api_info = api_info_file_checker.common_check()
461
540
  forward_content, backward_content, real_data_path = parse_json_info_forward_backward(checked_api_info)
541
+ if real_data_path:
542
+ dump_path = os.path.dirname(checked_api_info)
543
+ real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA)
462
544
  if args.filter_api:
463
545
  logger.info("Start filtering the api in the api_info_file.")
464
546
  forward_content = preprocess_forward_content(forward_content)
465
547
  logger.info("Finish filtering the api in the api_info_file.")
466
548
 
467
- out_path = os.path.realpath(args.out_path) if args.out_path else "./"
468
- check_path_before_create(out_path)
549
+ out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
469
550
  create_directory(out_path)
470
551
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
471
552
  out_path = out_path_checker.common_check()
@@ -476,43 +557,31 @@ def run_ut_command(args):
476
557
  if args.result_csv_path:
477
558
  result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
478
559
  details_csv_path = get_validated_details_csv_path(result_csv_path)
479
- white_list = msCheckerConfig.white_list
480
- black_list = msCheckerConfig.black_list
481
- error_data_path = msCheckerConfig.error_data_path
482
- is_online = msCheckerConfig.is_online
483
- nfs_path = msCheckerConfig.nfs_path
484
- host = msCheckerConfig.host
485
- port = msCheckerConfig.port
486
- rank_list = msCheckerConfig.rank_list
487
- tls_path = msCheckerConfig.tls_path
488
- if args.config_path:
489
- config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE,
490
- FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
491
- checked_config_path = config_path_checker.common_check()
492
- _, task_config = parse_json_config(checked_config_path, Const.RUN_UT)
493
- white_list = task_config.white_list
494
- black_list = task_config.black_list
495
- error_data_path = task_config.error_data_path
496
- is_online = task_config.is_online
497
- nfs_path = task_config.nfs_path
498
- host = task_config.host
499
- port = task_config.port
500
- rank_list = task_config.rank_list
501
- tls_path = task_config.tls_path
502
560
 
561
+ error_data_path = checker_config.error_data_path
503
562
  if save_error_data:
504
563
  if args.result_csv_path:
505
564
  time_info = result_csv_path.split('.')[0].split('_')[-1]
506
565
  global UT_ERROR_DATA_DIR
507
566
  UT_ERROR_DATA_DIR = 'ut_error_data' + time_info
508
567
  error_data_path = initialize_save_error_data(error_data_path)
509
- online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path)
510
- run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data,
511
- args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path,
512
- online_config)
568
+ online_config = checker_config.get_online_config()
569
+ checked_online_config(online_config)
570
+ config_params = {
571
+ 'forward_content': forward_content,
572
+ 'backward_content': backward_content,
573
+ 'result_csv_path': result_csv_path,
574
+ 'details_csv_path': details_csv_path,
575
+ 'save_error_data': save_error_data,
576
+ 'is_continue_run_ut': args.result_csv_path,
577
+ 'real_data_path': real_data_path,
578
+ 'error_data_path': error_data_path
579
+ }
580
+ run_ut_config = checker_config.get_run_ut_config(**config_params)
513
581
  run_ut(run_ut_config)
514
582
 
515
583
 
516
584
  if __name__ == '__main__':
585
+ seed_all()
517
586
  _run_ut()
518
587
  logger.info("UT task completed.")
@@ -16,6 +16,7 @@
16
16
  # limitations under the License.
17
17
 
18
18
  import os
19
+ from collections import namedtuple
19
20
  import re
20
21
  import torch
21
22
 
@@ -23,8 +24,10 @@ try:
23
24
  import torch_npu
24
25
  except ImportError:
25
26
  current_device = "cuda"
27
+ from torch.cuda.amp import autocast
26
28
  else:
27
29
  current_device = "npu"
30
+ from torch_npu.npu.amp import autocast
28
31
 
29
32
  from msprobe.core.common.const import FileCheckConst, Const, CompareConst
30
33
  from msprobe.core.common.file_utils import FileChecker
@@ -47,11 +50,17 @@ PRECISION_MAPPING = {
47
50
  }
48
51
 
49
52
 
53
+ CpuParams = namedtuple("CpuArgs", ["cpu_args", "cpu_kwargs", "autocast_dtype", "is_autocast"])
54
+ ExecParams = namedtuple("ExecParams", ["api_type", "api_name", "device", "args", "kwargs",
55
+ "is_autocast", "autocast_dtype"])
56
+
57
+
50
58
  class BackwardMessage:
51
59
  MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
52
60
  UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
53
61
  "skip backward."
54
- NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
62
+ NO_BACKWARD_RESULT_MESSAGE = "This API does not have backward input data, skip backward."
63
+ UNSUPPORT_API_MESSAGE = "This API does not support backward ut, skip backward."
55
64
 
56
65
 
57
66
  class UtDataInfo:
@@ -91,7 +100,15 @@ def get_validated_details_csv_path(validated_result_csv_path):
91
100
  return validated_details_csv_path
92
101
 
93
102
 
94
- def exec_api(api_type, api_name, device, args, kwargs):
103
+ def exec_api(exec_params):
104
+ api_type = exec_params.api_type
105
+ api_name = exec_params.api_name
106
+ device = exec_params.device
107
+ args = exec_params.args
108
+ kwargs = exec_params.kwargs
109
+ is_autocast = exec_params.is_autocast
110
+ autocast_dtype = exec_params.autocast_dtype
111
+
95
112
  if api_type == "Functional":
96
113
  torch_api = FunctionalOPTemplate(api_name, str, False)
97
114
  if api_type == "Tensor":
@@ -102,7 +119,11 @@ def exec_api(api_type, api_name, device, args, kwargs):
102
119
  torch_api = AtenOPTemplate(api_name, None, False)
103
120
  if api_type == "NPU":
104
121
  torch_api = NpuOPTemplate(api_name, None, False, device)
105
- out = torch_api.forward(*args, **kwargs)
122
+ if is_autocast:
123
+ with autocast(dtype=autocast_dtype):
124
+ out = torch_api.forward(*args, **kwargs)
125
+ else:
126
+ out = torch_api.forward(*args, **kwargs)
106
127
  return out
107
128
 
108
129
 
@@ -186,28 +207,48 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
186
207
  logger.error("The depth of arg_in is too large, please check the arg_in.")
187
208
  raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
188
209
  if isinstance(arg_in, (list, tuple)):
189
- return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for arg in arg_in))
210
+ return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for
211
+ arg in arg_in))
190
212
  elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
191
213
  return set([arg_in.dtype])
192
214
  elif isinstance(arg_in, dict) and check_kwargs:
193
- return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for v in arg_in.values()))
215
+ return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for
216
+ v in arg_in.values()))
194
217
  return set()
195
218
 
196
219
  raise_dtype = None
220
+ autocast_dtype = None
221
+ is_autocast = False
197
222
  need_raise_dtypes = recursive_find_dtypes(input_args)
198
223
  need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
199
224
  if len(need_raise_dtypes) == 1:
200
- raise_dtype = PRECISION_MAPPING.get(need_raise_dtypes.pop(), torch.float32)
225
+ origin_dtype = need_raise_dtypes.pop()
226
+ raise_dtype = PRECISION_MAPPING.get(origin_dtype, torch.float32)
227
+ autocast_dtype = origin_dtype
228
+
201
229
  elif len(need_raise_dtypes) >= 2:
202
230
  raise_dtype = torch.float32
231
+ need_raise_dtypes.discard(torch.float32)
232
+ autocast_dtype = need_raise_dtypes.pop()
233
+ is_autocast = True
203
234
 
204
235
  raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
205
236
  is_detach = api_name not in not_detach_set
206
237
  cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
207
- cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
208
- return cpu_args, cpu_kwargs
238
+ cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for
239
+ key, value in input_kwargs.items()}
240
+ cpu_params = CpuParams(cpu_args, cpu_kwargs, autocast_dtype, is_autocast)
241
+ return cpu_params
209
242
 
210
243
 
211
244
  def record_skip_info(api_full_name, compare, compare_alg_results):
212
245
  result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
213
246
  compare.record_results(result_info)
247
+
248
+
249
+ def is_unsupported_api(api_name, is_overflow_check=False):
250
+ split_name = api_name.split(Const.SEP)[0]
251
+ flag = (split_name == Const.DISTRIBUTED) or (is_overflow_check and split_name == Const.NPU)
252
+ if flag:
253
+ logger.info(f"{split_name} api is not supported for run ut. SKIP.")
254
+ return flag