mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
@@ -2,13 +2,11 @@ import time
2
2
  import os
3
3
  import math
4
4
 
5
- import numpy as np
6
5
  import torch
7
- import yaml
8
- from msprobe.core.common.utils import CompareException
6
+
7
+ from msprobe.core.common.utils import CompareException, load_yaml
9
8
  from msprobe.core.common.const import Const
10
9
  from msprobe.pytorch.common.log import logger
11
- from msprobe.core.common.file_check import FileOpen
12
10
 
13
11
 
14
12
  current_time = time.strftime("%Y%m%d%H%M%S")
@@ -22,17 +20,15 @@ BINARY_COMPARE_UNSUPPORT_LIST = BENCHMARK_COMPARE_SUPPORT_LIST + API_PRECISION_C
22
20
 
23
21
  cur_path = os.path.dirname(os.path.realpath(__file__))
24
22
  standard_yaml_path = os.path.join(cur_path, "api_precision_standard.yaml")
25
- with FileOpen(standard_yaml_path, 'r') as f:
26
- Apis = yaml.safe_load(f)
27
- AbsoluteStandardApi = Apis.get('AbsoluteThreshStandard')
28
- BinaryStandardApi = Apis.get('BinaryCompareStandard')
29
- ULPStandardApi = Apis.get('ULPStandard')
30
- ThousandthStandardApi = Apis.get('ThousandthStandard')
23
+ apis = load_yaml(standard_yaml_path)
24
+ absolute_standard_api = apis.get('AbsoluteThreshStandard')
25
+ binary_standard_api = apis.get('BinaryCompareStandard')
26
+ ulp_standard_api = apis.get('ULPStandard')
27
+ thousandth_standard_api = apis.get('ThousandthStandard')
31
28
 
32
29
 
33
30
  threshold_yaml_path = os.path.join(cur_path, "api_precision_threshold.yaml")
34
- with FileOpen(threshold_yaml_path, 'r') as f:
35
- apis_threshold = yaml.safe_load(f)
31
+ apis_threshold = load_yaml(threshold_yaml_path)
36
32
 
37
33
 
38
34
  DETAIL_TEST_ROWS = [[
@@ -1,4 +1,10 @@
1
1
  white_list: []
2
+ black_list: []
2
3
  error_data_path: './'
3
4
  precision: 14
4
-
5
+ is_online: False
6
+ nfs_path: ""
7
+ host: ""
8
+ port: -1
9
+ rank_list: [0]
10
+ tls_path: ""
@@ -21,10 +21,11 @@ import torch
21
21
  import numpy
22
22
 
23
23
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
24
- from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, \
25
- get_full_data_path, CompareException
24
+ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
25
+ CompareException
26
+ from msprobe.core.common.file_check import FileChecker
26
27
  from msprobe.pytorch.common.log import logger
27
- from msprobe.core.common.const import Const
28
+ from msprobe.core.common.const import Const, FileCheckConst
28
29
 
29
30
  TORCH_TYPE = ["torch.device", "torch.dtype"]
30
31
  TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
@@ -87,12 +88,13 @@ def gen_real_tensor(data_path, convert_type):
87
88
  convert_type: convert ori_type to dist_type flag.
88
89
  """
89
90
  data_path = os.path.realpath(data_path)
90
- check_file_or_directory_path(data_path)
91
+ data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
92
+ data_path = data_path_checker.common_check()
91
93
  if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
92
94
  error_info = f"The file: {data_path} is not a pt or numpy file."
93
95
  raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
94
96
  if data_path.endswith('.pt'):
95
- data = torch.load(data_path).cpu()
97
+ data = torch.load(data_path, map_location=torch.device('cpu'))
96
98
  else:
97
99
  data_np = numpy.load(data_path)
98
100
  data = torch.from_numpy(data_np)
@@ -255,12 +257,13 @@ def gen_args(args_info, api_name, need_grad=True, convert_type=None, real_data_p
255
257
  return args_result
256
258
 
257
259
 
258
- def gen_kwargs(api_info, convert_type=None, real_data_path=None):
260
+ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None):
259
261
  """
260
262
  Function Description:
261
263
  Based on API basic information, generate input parameters: kwargs, for API forward running
262
264
  Parameter:
263
265
  api_info: API basic information. Dict
266
+ api_name: API name
264
267
  convert_type: convert ori_type to dist_type flag.
265
268
  real_data_path: the root directory for storing real data.
266
269
  """
@@ -268,11 +271,11 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None):
268
271
  kwargs_params = api_info.get("input_kwargs")
269
272
  for key, value in kwargs_params.items():
270
273
  if isinstance(value, (list, tuple)):
271
- kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path)
274
+ kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path)
272
275
  elif value is None:
273
276
  kwargs_params[key] = None
274
277
  elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"):
275
- kwargs_params[key] = gen_data(value, True, convert_type, real_data_path)
278
+ kwargs_params[key] = gen_data(value, api_name, True, convert_type, real_data_path)
276
279
  elif value.get('type') in TORCH_TYPE:
277
280
  gen_torch_kwargs(kwargs_params, key, value)
278
281
  else:
@@ -285,18 +288,19 @@ def gen_torch_kwargs(kwargs_params, key, value):
285
288
  kwargs_params[key] = eval(value.get('value'))
286
289
 
287
290
 
288
- def gen_list_kwargs(kwargs_item_value, convert_type, real_data_path=None):
291
+ def gen_list_kwargs(kwargs_item_value, api_name, convert_type, real_data_path=None):
289
292
  """
290
293
  Function Description:
291
294
  When kwargs value is list, generate the list of kwargs result
292
295
  Parameter:
293
296
  kwargs_item_value: kwargs value before to generate. List
297
+ api_name: API name
294
298
  convert_type: convert ori_type to dist_type flag.
295
299
  """
296
300
  kwargs_item_result = []
297
301
  for item in kwargs_item_value:
298
302
  if item.get('type') in TENSOR_DATA_LIST:
299
- item_value = gen_data(item, False, convert_type, real_data_path)
303
+ item_value = gen_data(item, api_name, False, convert_type, real_data_path)
300
304
  elif item.get('type') == "torch.Size":
301
305
  item_value = torch.Size(item.get('value'))
302
306
  else:
@@ -319,7 +323,7 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d
319
323
  if convert_type and convert_type not in Const.CONVERT:
320
324
  error_info = f"convert_type params not support {convert_type}."
321
325
  raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info)
322
- kwargs_params = gen_kwargs(api_info, convert_type, real_data_path)
326
+ kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path)
323
327
  if api_info.get("input_args"):
324
328
  args_params = gen_args(api_info.get("input_args"), api_name, need_grad, convert_type, real_data_path)
325
329
  else:
@@ -9,8 +9,9 @@ import threading
9
9
  from collections import namedtuple
10
10
  from itertools import cycle
11
11
  from tqdm import tqdm
12
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, \
13
- get_validated_details_csv_path, preprocess_forward_content
12
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, preprocess_forward_content
13
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path, \
14
+ get_validated_details_csv_path
14
15
  from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
15
16
  from msprobe.pytorch.common import parse_json_info_forward_backward
16
17
  from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
@@ -68,7 +69,7 @@ signal.signal(signal.SIGTERM, signal_handler)
68
69
 
69
70
  ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
70
71
  'save_error_data_flag', 'jit_compile_flag', 'device_id',
71
- 'result_csv_path', 'total_items', 'real_data_path'])
72
+ 'result_csv_path', 'total_items', 'config_path'])
72
73
 
73
74
 
74
75
  def run_parallel_ut(config):
@@ -90,7 +91,7 @@ def run_parallel_ut(config):
90
91
  *(['-j'] if config.jit_compile_flag else []),
91
92
  *(['-save_error_data'] if config.save_error_data_flag else []),
92
93
  '-csv_path', config.result_csv_path,
93
- *(['-real_data_path', config.real_data_path] if config.real_data_path else [])
94
+ *(['-config', config.config_path] if config.config_path else [])
94
95
  ]
95
96
  return cmd
96
97
 
@@ -110,19 +111,14 @@ def run_parallel_ut(config):
110
111
 
111
112
  def update_progress_bar(progress_bar, result_csv_path):
112
113
  while any(process.poll() is None for process in processes):
113
- try:
114
- with open(result_csv_path, 'r') as result_file:
115
- completed_items = len(result_file.readlines()) - 1
116
- progress_bar.update(completed_items - progress_bar.n)
117
- except FileNotFoundError:
118
- logger.warning(f"Result CSV file not found: {result_csv_path}.")
119
- except Exception as e:
120
- logger.error(f"An unexpected error occurred while reading result CSV: {e}")
114
+ with FileOpen(result_csv_path, 'r') as result_file:
115
+ completed_items = len(result_file.readlines()) - 1
116
+ progress_bar.update(completed_items - progress_bar.n)
121
117
  time.sleep(1)
122
118
 
123
119
  for api_info in config.api_files:
124
120
  cmd = create_cmd(api_info, next(device_id_cycle))
125
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1)
121
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1, shell=False)
126
122
  processes.append(process)
127
123
  threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
128
124
 
@@ -175,7 +171,7 @@ def prepare_config(args):
175
171
  out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
176
172
  out_path = out_path_checker.common_check()
177
173
  split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
178
-
174
+ config_path = os.path.realpath(args.config_path) if args.config_path else None
179
175
  result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
180
176
  if not args.result_csv_path:
181
177
  details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
@@ -187,7 +183,7 @@ def prepare_config(args):
187
183
  logger.info(f"UT task details will be saved in {details_csv_path}")
188
184
  return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
189
185
  args.jit_compile, args.device_id, result_csv_path,
190
- total_items, args.real_data_path)
186
+ total_items, config_path)
191
187
 
192
188
 
193
189
  def main():
@@ -10,10 +10,14 @@ else:
10
10
  is_gpu = False
11
11
  import torch
12
12
  from tqdm import tqdm
13
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
14
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents
13
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import generate_device_params, get_api_info
14
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import exec_api
15
+ from msprobe.core.common.utils import get_json_contents
15
16
  from msprobe.core.common.file_check import check_link
16
17
  from msprobe.pytorch.common.log import logger
18
+ from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
19
+ from msprobe.core.common.const import Const
20
+
17
21
 
18
22
  def check_tensor_overflow(x):
19
23
  if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool:
@@ -52,12 +56,12 @@ def check_data_overflow(x):
52
56
 
53
57
  def run_overflow_check(forward_file):
54
58
  logger.info("start UT test")
55
- forward_content = get_json_contents(forward_file)
59
+ forward_content, _, real_data_path = parse_json_info_forward_backward(forward_file)
56
60
  for api_full_name, api_info_dict in tqdm(forward_content.items()):
57
61
  try:
58
- run_torch_api(api_full_name, api_info_dict)
62
+ run_torch_api(api_full_name, api_info_dict, real_data_path)
59
63
  except Exception as err:
60
- api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0]
64
+ _, api_name, _ = api_full_name.split(Const.SEP)
61
65
  if "not implemented for 'Half'" in str(err):
62
66
  logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API "
63
67
  f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
@@ -68,11 +72,10 @@ def run_overflow_check(forward_file):
68
72
  logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
69
73
 
70
74
 
71
- def run_torch_api(api_full_name, api_info_dict):
75
+ def run_torch_api(api_full_name, api_info_dict, real_data_path):
72
76
  torch.npu.clear_npu_overflow_flag()
73
- api_type = api_full_name.split(".")[0]
74
- api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0]
75
- args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='')
77
+ api_type, api_name, _ = api_full_name.split(Const.SEP)
78
+ args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
76
79
  if not need_grad:
77
80
  logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
78
81
  % api_full_name)
@@ -81,6 +84,10 @@ def run_torch_api(api_full_name, api_info_dict):
81
84
  del kwargs["device"]
82
85
  out = exec_api(api_type, api_name, args, kwargs)
83
86
  npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
87
+ if out is None and npu_out is None:
88
+ logger.warning("The %s overflow is a normal overflow, out and npu_out is None." % api_full_name)
89
+ return
90
+
84
91
  cpu_overflow = check_data_overflow(out)
85
92
  npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out)
86
93
  if cpu_overflow == npu_overflow: