mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. /msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. /msprobe/pytorch/doc/{atat → msprobe}/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md" +0 -0
msprobe/msprobe.py CHANGED
@@ -15,13 +15,16 @@
15
15
 
16
16
  import argparse
17
17
  import sys
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
19
- from msprobe.pytorch.parse_tool.cli import parse as cli_parse
20
- from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
21
- from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
22
- _api_precision_compare_command
23
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
24
- _run_overflow_check_command
18
+ import importlib.util
19
+ from msprobe.core.compare.utils import _compare_parser
20
+ from msprobe.core.common.log import logger
21
+ from msprobe.core.compare.compare_cli import compare_cli
22
+ from msprobe.core.common.const import Const
23
+
24
+
25
+ def is_module_available(module_name):
26
+ spec = importlib.util.find_spec(module_name)
27
+ return spec is not None
25
28
 
26
29
 
27
30
  def main():
@@ -31,37 +34,74 @@ def main():
31
34
  "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n"
32
35
  f"For any issue, refer README.md first",
33
36
  )
37
+
34
38
  parser.set_defaults(print_help=parser.print_help)
35
- parser.add_argument('-f', '--framework', required=True, choices=['pytorch'],
39
+ parser.add_argument('-f', '--framework', required=True, choices=[Const.PT_FRAMEWORK, Const.MS_FRAMEWORK],
36
40
  help='Deep learning framework.')
37
41
  subparsers = parser.add_subparsers()
38
42
  subparsers.add_parser('parse')
43
+ compare_cmd_parser = subparsers.add_parser('compare')
39
44
  run_ut_cmd_parser = subparsers.add_parser('run_ut')
40
45
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
41
46
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
42
47
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
43
- _run_ut_parser(run_ut_cmd_parser)
44
- _run_ut_parser(multi_run_ut_cmd_parser)
45
- multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
46
- help='Number of splits for parallel processing. Range: 1-64')
47
- _api_precision_compare_parser(api_precision_compare_cmd_parser)
48
- _run_overflow_check_parser(run_overflow_check_cmd_parser)
48
+ _compare_parser(compare_cmd_parser)
49
+ is_torch_available=is_module_available("torch")
50
+ is_mindspore_available = is_module_available("mindspore")
51
+ if is_torch_available:
52
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
53
+ from msprobe.pytorch.parse_tool.cli import parse as cli_parse
54
+ from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
55
+ from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \
56
+ _api_precision_compare_command
57
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
58
+ _run_overflow_check_command
59
+
60
+ _run_ut_parser(run_ut_cmd_parser)
61
+ _run_ut_parser(multi_run_ut_cmd_parser)
62
+ multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
63
+ help='Number of splits for parallel processing. Range: 1-64')
64
+ _api_precision_compare_parser(api_precision_compare_cmd_parser)
65
+ _run_overflow_check_parser(run_overflow_check_cmd_parser)
66
+ elif is_mindspore_available:
67
+ from msprobe.mindspore.api_accuracy_checker.main import add_api_accuracy_checker_argument
68
+ add_api_accuracy_checker_argument(run_ut_cmd_parser)
69
+
49
70
  if len(sys.argv) == 1:
50
71
  parser.print_help()
51
72
  sys.exit(0)
52
73
  args = parser.parse_args(sys.argv[1:])
53
- if sys.argv[3] == "run_ut":
54
- run_ut_command(args)
55
- elif sys.argv[3] == "parse":
56
- cli_parse()
57
- elif sys.argv[3] == "multi_run_ut":
58
- config = prepare_config(args)
59
- run_parallel_ut(config)
60
- elif sys.argv[3] == "api_precision_compare":
61
- _api_precision_compare_command(args)
62
- elif sys.argv[3] == "run_overflow_check":
63
- _run_overflow_check_command(args)
64
-
74
+ if sys.argv[2] == Const.PT_FRAMEWORK:
75
+ if not is_torch_available:
76
+ logger.error("PyTorch does not exist, please install PyTorch library")
77
+ raise Exception("PyTorch does not exist, please install PyTorch library")
78
+ if sys.argv[3] == "run_ut":
79
+ run_ut_command(args)
80
+ elif sys.argv[3] == "parse":
81
+ cli_parse()
82
+ elif sys.argv[3] == "multi_run_ut":
83
+ config = prepare_config(args)
84
+ run_parallel_ut(config)
85
+ elif sys.argv[3] == "api_precision_compare":
86
+ _api_precision_compare_command(args)
87
+ elif sys.argv[3] == "run_overflow_check":
88
+ _run_overflow_check_command(args)
89
+ elif sys.argv[3] == "compare":
90
+ if args.cell_mapping is not None or args.api_mapping is not None:
91
+ logger.error("Argument -cm or -am is not supported in PyTorch framework")
92
+ raise Exception("Argument -cm or -am is not supported in PyTorch framework")
93
+ compare_cli(args)
94
+ else:
95
+ if not is_module_available(Const.MS_FRAMEWORK):
96
+ logger.error("MindSpore does not exist, please install MindSpore library")
97
+ raise Exception("MindSpore does not exist, please install MindSpore library")
98
+ if sys.argv[3] == "compare":
99
+ if isinstance(args.api_mapping, str):
100
+ logger.warning("User defined mapping tables are not supported in the current version")
101
+ compare_cli(args)
102
+ elif sys.argv[3] == "run_ut":
103
+ from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
104
+ api_checker_main(args)
65
105
 
66
106
  if __name__ == "__main__":
67
107
  main()
@@ -1,4 +1,4 @@
1
1
  from .debugger.precision_debugger import PrecisionDebugger
2
2
  from .common.utils import seed_all
3
- from .compare.acc_compare import compare
4
3
  from .compare.distributed_compare import compare_distributed
4
+ from .compare.pt_compare import compare
@@ -1,17 +1,14 @@
1
1
  import os
2
2
  import yaml
3
- from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path
4
- from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps
5
- from msprobe.core.common.file_check import FileOpen
6
-
7
- WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
3
+ from msprobe.core.common.utils import check_file_or_directory_path
4
+ from msprobe.core.common.utils import load_yaml
5
+ from msprobe.pytorch.pt_config import RunUTConfig
8
6
 
9
7
 
10
8
  class Config:
11
9
  def __init__(self, yaml_file):
12
10
  check_file_or_directory_path(yaml_file, False)
13
- with FileOpen(yaml_file, 'r') as file:
14
- config = yaml.safe_load(file)
11
+ config = load_yaml(yaml_file)
15
12
  self.config = {key: self.validate(key, value) for key, value in config.items()}
16
13
 
17
14
  def __getattr__(self, item):
@@ -24,8 +21,15 @@ class Config:
24
21
  def validate(key, value):
25
22
  validators = {
26
23
  'white_list': list,
24
+ 'black_list': list,
27
25
  'error_data_path': str,
28
- 'precision': int
26
+ 'precision': int,
27
+ 'is_online': bool,
28
+ 'nfs_path': str,
29
+ 'host': str,
30
+ 'port': int,
31
+ 'rank_list': list,
32
+ 'tls_path': str
29
33
  }
30
34
  if key not in validators:
31
35
  raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -34,14 +38,15 @@ class Config:
34
38
  if key == 'precision' and value < 0:
35
39
  raise ValueError("precision must be greater than 0")
36
40
  if key == 'white_list':
37
- if not isinstance(value, list):
38
- raise ValueError("white_list must be a list type")
39
- if not all(isinstance(i, str) for i in value):
40
- raise ValueError("All elements in white_list must be of str type")
41
- invalid_api = [i for i in value if i not in WrapApi]
42
- if invalid_api:
43
- raise ValueError(
44
- f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list")
41
+ RunUTConfig.check_filter_list_config(key, value)
42
+ if key == 'black_list':
43
+ RunUTConfig.check_filter_list_config(key, value)
44
+ if key == 'error_data_path':
45
+ RunUTConfig.check_error_data_path_config(value)
46
+ if key == 'nfs_path':
47
+ RunUTConfig.check_nfs_path_config(value)
48
+ if key == 'tls_path':
49
+ RunUTConfig.check_tls_path_config(value)
45
50
  return value
46
51
 
47
52
 
@@ -14,10 +14,8 @@
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
16
  """
17
- import json
18
17
  import os
19
18
  import re
20
- import csv
21
19
 
22
20
  import torch
23
21
 
@@ -38,12 +36,6 @@ class DumpException(CompareException):
38
36
  pass
39
37
 
40
38
 
41
- def write_csv(data, filepath):
42
- with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
43
- writer = csv.writer(f)
44
- writer.writerows(data)
45
-
46
-
47
39
  def check_object_type(check_object, allow_type):
48
40
  """
49
41
  Function Description:
@@ -59,58 +51,6 @@ def check_object_type(check_object, allow_type):
59
51
  raise CompareException(CompareException.INVALID_DATA_ERROR)
60
52
 
61
53
 
62
- def check_file_or_directory_path(path, isdir=False):
63
- """
64
- Function Description:
65
- check whether the path is valid
66
- Parameter:
67
- path: the path to check
68
- isdir: the path is dir or file
69
- Exception Description:
70
- when invalid data throw exception
71
- """
72
- if isdir:
73
- if not os.path.exists(path):
74
- logger.error('The path {} is not exist.'.format(path))
75
- raise CompareException(CompareException.INVALID_PATH_ERROR)
76
-
77
- if not os.path.isdir(path):
78
- logger.error('The path {} is not a directory.'.format(path))
79
- raise CompareException(CompareException.INVALID_PATH_ERROR)
80
-
81
- if not os.access(path, os.W_OK):
82
- logger.error(
83
- 'The path {} does not have permission to write. Please check the path permission'.format(path))
84
- raise CompareException(CompareException.INVALID_PATH_ERROR)
85
- else:
86
- if not os.path.isfile(path):
87
- logger.error('{} is an invalid file or non-exist.'.format(path))
88
- raise CompareException(CompareException.INVALID_PATH_ERROR)
89
-
90
- if not os.access(path, os.R_OK):
91
- logger.error(
92
- 'The path {} does not have permission to read. Please check the path permission'.format(path))
93
- raise CompareException(CompareException.INVALID_PATH_ERROR)
94
-
95
-
96
- def get_json_contents(file_path):
97
- ops = get_file_content_bytes(file_path)
98
- try:
99
- json_obj = json.loads(ops)
100
- except ValueError as error:
101
- logger.error('Failed to load "%s". %s' % (file_path, str(error)))
102
- raise CompareException(CompareException.INVALID_FILE_ERROR) from error
103
- if not isinstance(json_obj, dict):
104
- logger.error('Json file %s, content is not a dictionary!' % file_path)
105
- raise CompareException(CompareException.INVALID_FILE_ERROR)
106
- return json_obj
107
-
108
-
109
- def get_file_content_bytes(file):
110
- with FileOpen(file, 'rb') as file_handle:
111
- return file_handle.read()
112
-
113
-
114
54
  class SoftlinkCheckException(Exception):
115
55
  pass
116
56
 
@@ -166,6 +106,7 @@ def initialize_save_path(save_path, dir_name):
166
106
  os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
167
107
  data_path_checker = FileChecker(data_path, FileCheckConst.DIR)
168
108
  data_path_checker.common_check()
109
+ return data_path
169
110
 
170
111
 
171
112
  def write_pt(file_path, tensor):
@@ -6,9 +6,6 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAM
6
6
  from msprobe.core.common.const import CompareConst
7
7
 
8
8
 
9
- DEFAULT_THRESHOLD = 1
10
-
11
-
12
9
  #cos
13
10
  def cosine_sim(bench_output, device_output):
14
11
  msg = ""
@@ -197,8 +194,8 @@ def check_norm_value(normal_value_mask, rel_err, rtol):
197
194
 
198
195
  def get_ulp_err(bench_output, device_output, dtype):
199
196
  parameters = ULP_PARAMETERS.get(dtype)
200
- min_eb = parameters.get('min_eb', DEFAULT_THRESHOLD)[0]
201
- exponent_num = parameters.get('exponent_num', DEFAULT_THRESHOLD)[0]
197
+ min_eb = parameters.get('min_eb')[0]
198
+ exponent_num = parameters.get('exponent_num')[0]
202
199
  abs_bench = np.abs(bench_output)
203
200
  eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
204
201
  eb = np.maximum(eb, min_eb)
@@ -7,19 +7,19 @@ from collections import namedtuple
7
7
  import torch
8
8
  import pandas as pd
9
9
 
10
- from msprobe.pytorch.api_accuracy_checker.common.utils import write_csv
10
+ from msprobe.core.common.utils import write_csv
11
11
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
12
12
  from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
13
13
  API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
14
- ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, ThousandthStandardApi, \
14
+ ApiPrecisionCompareColumn, absolute_standard_api, binary_standard_api, ulp_standard_api, thousandth_standard_api, \
15
15
  BINARY_COMPARE_UNSUPPORT_LIST, ULP_COMPARE_SUPPORT_LIST, convert_str_to_float, CompareMessage, is_inf_or_nan, \
16
16
  check_inf_or_nan
17
17
  from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
18
- from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
18
+ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path
19
19
  from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
20
20
  from msprobe.pytorch.common.log import logger
21
21
  from msprobe.core.common.utils import CompareException
22
- from msprobe.core.common.const import CompareConst, FileCheckConst
22
+ from msprobe.core.common.const import CompareConst, FileCheckConst, Const
23
23
 
24
24
  CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
25
25
  BenchmarkInf_Nan_Consistency = namedtuple('BenchmarkInf_Nan_Consistency', ['small_value_inf_nan_consistency',
@@ -289,15 +289,38 @@ def api_precision_compare(config):
289
289
  change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
290
290
 
291
291
 
292
+ def online_api_precision_compare(online_config):
293
+ rank = online_config.rank
294
+ result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
295
+ details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace("_rank*.csv", f"_rank{rank}.csv")
296
+ detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
297
+ result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
298
+ if not os.path.exists(result_csv_path):
299
+ write_csv(result_csv_title, result_csv_path)
300
+ if not os.path.exists(details_csv_path):
301
+ write_csv(detail_csv_title, details_csv_path)
302
+ config = CompareConfig("", "", result_csv_path, details_csv_path)
303
+ try:
304
+ npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
305
+ check_csv_columns(npu_data.columns, "npu_csv")
306
+ check_csv_columns(gpu_data.columns, "gpu_csv")
307
+ analyse_csv(npu_data, gpu_data, config)
308
+ except Exception as err:
309
+ logger.error(f"Online api precision compare Error: {str(err)}")
310
+ change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
311
+ change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
312
+
313
+
292
314
  def analyse_csv(npu_data, gpu_data, config):
293
315
  forward_status, backward_status = [], []
294
- last_api_name, last_api_dtype = None, None
316
+ last_api_name, last_api_dtype, last_api_full_name = None, None, None
295
317
  for _, row_npu in npu_data.iterrows():
296
318
  message = ''
297
319
  compare_column = ApiPrecisionOutputColumn()
298
320
  full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME]
299
321
  row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status]
300
- _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".")
322
+ api_type, api_name, api_nums, direction_status, _, _ = full_api_name_with_direction_status.split(Const.SEP)
323
+ api_full_name = Const.SEP.join([api_type, api_name, api_nums])
301
324
  if row_gpu.empty:
302
325
  logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.')
303
326
  continue
@@ -315,14 +338,14 @@ def analyse_csv(npu_data, gpu_data, config):
315
338
  write_detail_csv(compare_column.to_column_value(), config.details_csv_path)
316
339
  else:
317
340
  compare_column.api_name = full_api_name_with_direction_status
318
- if api_name in ThousandthStandardApi:
341
+ if api_name in thousandth_standard_api:
319
342
  new_status = record_thousandth_threshold_result(compare_column, row_npu)
320
343
  elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or \
321
- api_name in BinaryStandardApi:
344
+ api_name in binary_standard_api:
322
345
  new_status = record_binary_consistency_result(api_name, compare_column, row_npu)
323
- elif api_name in AbsoluteStandardApi:
346
+ elif api_name in absolute_standard_api:
324
347
  new_status = record_absolute_threshold_result(compare_column, row_npu)
325
- elif api_name in ULPStandardApi and \
348
+ elif api_name in ulp_standard_api and \
326
349
  row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in ULP_COMPARE_SUPPORT_LIST:
327
350
  us = ULPStandard(full_api_name_with_direction_status, row_npu, row_gpu)
328
351
  new_status = record_ulp_compare_result(compare_column, us)
@@ -335,6 +358,7 @@ def analyse_csv(npu_data, gpu_data, config):
335
358
  if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
336
359
  message = unsupported_message
337
360
  write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
361
+ print_test_success(api_full_name, "skip", "skip")
338
362
  forward_status, backward_status = [], []
339
363
  message = ''
340
364
  else:
@@ -342,11 +366,13 @@ def analyse_csv(npu_data, gpu_data, config):
342
366
  backward_result = get_api_checker_result(backward_status)
343
367
  message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
344
368
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
369
+ print_test_success(api_full_name, forward_result, backward_result)
345
370
  forward_status, backward_status = [], []
346
371
  message = ''
347
372
 
348
373
  is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST
349
374
  last_api_name = api_name
375
+ last_api_full_name = api_full_name
350
376
 
351
377
  last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE]
352
378
  if not is_supported:
@@ -363,11 +389,21 @@ def analyse_csv(npu_data, gpu_data, config):
363
389
  if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST:
364
390
  message = unsupported_message
365
391
  write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path)
392
+ print_test_success(last_api_full_name, "skip", "skip")
366
393
  else:
367
394
  forward_result = get_api_checker_result(forward_status)
368
395
  backward_result = get_api_checker_result(backward_status)
369
396
  message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else ""
370
397
  write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path)
398
+ print_test_success(last_api_full_name, forward_result, backward_result)
399
+
400
+
401
+ def print_test_success(api_full_name, forward_result, backward_result):
402
+ is_fwd_success = (forward_result == CompareConst.PASS)
403
+ is_bwd_success = (backward_result == CompareConst.PASS or backward_result == CompareConst.SPACE)
404
+ logger.info(f"running api_full_name {api_full_name} compare, "
405
+ f"is_fwd_success: {is_fwd_success}, "
406
+ f"is_bwd_success: {is_bwd_success}")
371
407
 
372
408
 
373
409
  def check_error_rate(npu_error_rate):
@@ -1,27 +1,28 @@
1
1
  # 进行比对及结果展示
2
2
  import os
3
3
  from collections import namedtuple
4
- import torch
4
+
5
5
  import numpy as np
6
- from msprobe.pytorch.common.log import logger
7
- from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv
8
- from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
9
- DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \
10
- ULPStandardApi, ThousandthStandardApi, apis_threshold
11
- from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
6
+ from msprobe.core.common.utils import write_csv, get_json_contents, CompareException
7
+ import torch
8
+ from msprobe.core.common.const import Const, CompareConst
12
9
  from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \
13
10
  get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
14
11
  get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
15
12
  check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
16
13
  from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
17
- from msprobe.core.common.const import Const, CompareConst
14
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
15
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
16
+ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, absolute_standard_api, binary_standard_api, \
17
+ ulp_standard_api, thousandth_standard_api, apis_threshold
18
+ from msprobe.pytorch.common.log import logger
18
19
 
19
20
 
20
21
  ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
21
22
  'fwd_compare_alg_results', 'bwd_compare_alg_results', 'rank'])
22
23
 
23
24
 
24
- INDEX_TEST_RESULT__GROUP = 3
25
+ INDEX_TEST_RESULT_GROUP = 3
25
26
  INDEX_FIRST_GROUP = 0
26
27
  INDEX_MESSAGE = -1
27
28
 
@@ -33,20 +34,34 @@ class Comparator:
33
34
  COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
34
35
  COLUMN_STACK_INFO = "Traceback callstack info"
35
36
 
36
- def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
37
- self.save_path = result_csv_path
38
- self.detail_save_path = details_csv_path
39
- if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path):
37
+ def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None):
38
+ self.save_path_str = result_csv_path
39
+ self.detail_save_path_str = details_csv_path
40
+ self.save_path_list = [result_csv_path]
41
+ self.detail_save_path_list = [details_csv_path]
42
+
43
+ if config and config.online_config.is_online:
44
+ self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
45
+ self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
46
+ self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
47
+ self.detail_save_path_list = \
48
+ [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
49
+
50
+ if not is_continue_run_ut:
40
51
  self.write_csv_title()
41
52
  if stack_info_json_path:
42
53
  self.stack_info = get_json_contents(stack_info_json_path)
43
54
  else:
44
55
  self.stack_info = None
45
56
 
57
+ @staticmethod
58
+ def get_path_from_rank(rank, path_list, path_pattern):
59
+ return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)
60
+
46
61
  @staticmethod
47
62
  def print_pretest_result():
48
63
  logger.info("Successfully completed run_ut/multi_run_ut.")
49
-
64
+
50
65
  @staticmethod
51
66
  def _compare_dropout(bench_output, device_output):
52
67
  tensor_num = bench_output.numel()
@@ -75,7 +90,7 @@ class Comparator:
75
90
  error_rate = float(error_nums / bench_output.size)
76
91
  result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR
77
92
  return error_rate, result, ""
78
-
93
+
79
94
  @staticmethod
80
95
  def _get_absolute_threshold_attribute(api_name, dtype):
81
96
  small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value')
@@ -83,35 +98,18 @@ class Comparator:
83
98
  rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
84
99
  return small_value_threshold, small_value_atol, rtol
85
100
 
86
- def write_csv_title(self):
87
- summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
88
- self.COLUMN_BACKWARD_SUCCESS, "Message"]]
89
- if not os.path.exists(self.save_path):
90
- write_csv(summary_test_rows, self.save_path)
91
- if not os.path.exists(self.detail_save_path):
92
- write_csv(DETAIL_TEST_ROWS, self.detail_save_path)
93
-
94
- def write_summary_csv(self, test_result):
95
- test_rows = []
96
- if self.stack_info:
97
- test_rows[0].append(self.COLUMN_STACK_INFO)
98
-
99
- name = test_result[0]
100
- df_row = list(test_result[:INDEX_TEST_RESULT__GROUP])
101
- if test_result[1] == "SKIP":
102
- df_row.append(test_result[INDEX_TEST_RESULT__GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
103
- if self.stack_info:
104
- stack_info = "\n".join(self.stack_info[name])
105
- df_row.append(stack_info)
106
- test_rows.append(df_row)
107
- write_csv(test_rows, self.save_path)
108
-
109
- def write_detail_csv(self, test_result):
101
+ @staticmethod
102
+ def _get_run_ut_detail(test_result):
103
+ """get run_ut detail before write to csv, called by online run_ut"""
110
104
  test_rows = []
105
+ try:
106
+ subject_prefix = test_result[0]
107
+ fwd_result = test_result[3]
108
+ bwd_result = test_result[4]
109
+ except IndexError as e:
110
+ logger.error("List index out of bounds when writing detail CSV.")
111
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
111
112
 
112
- subject_prefix = test_result[0]
113
- fwd_result = test_result[3]
114
- bwd_result = test_result[4]
115
113
  if isinstance(fwd_result, list):
116
114
  for i, test_subject in enumerate(fwd_result):
117
115
  subject = subject_prefix + ".forward.output." + str(i)
@@ -124,14 +122,49 @@ class Comparator:
124
122
  test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision)
125
123
  if isinstance(item, float) else item for item in test_subject]
126
124
  test_rows.append([subject] + list(test_subject))
125
+ return test_rows
127
126
 
128
- write_csv(test_rows, self.detail_save_path)
127
+ def write_csv_title(self):
128
+ summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
129
+ self.COLUMN_BACKWARD_SUCCESS, "Message"]]
130
+ for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list):
131
+ if not os.path.exists(save_path):
132
+ write_csv(summary_test_rows, save_path)
133
+ if not os.path.exists(detail_save_path):
134
+ write_csv(DETAIL_TEST_ROWS, detail_save_path)
135
+
136
+ def write_summary_csv(self, test_result):
137
+ test_rows = []
138
+ try:
139
+ name = test_result[0]
140
+ df_row = list(test_result[:INDEX_TEST_RESULT_GROUP])
141
+ if test_result[1] == "SKIP":
142
+ df_row.append(test_result[INDEX_TEST_RESULT_GROUP][INDEX_FIRST_GROUP][INDEX_MESSAGE])
143
+ if self.stack_info:
144
+ stack_info = "\n".join(self.stack_info[name])
145
+ df_row.append(stack_info)
146
+ test_rows.append(df_row)
147
+ save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str)
148
+ except IndexError as e:
149
+ logger.error("List index out of bounds when writing summary CSV.")
150
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR, "list index out of bounds") from e
151
+ write_csv(test_rows, save_path)
152
+
153
+ def write_detail_csv(self, test_result):
154
+ test_rows = self._get_run_ut_detail(test_result)
155
+ detail_save_path = self.get_path_from_rank(test_result[-1],
156
+ self.detail_save_path_list,
157
+ self.detail_save_path_str)
158
+ write_csv(test_rows, detail_save_path)
129
159
 
130
160
  def record_results(self, args):
131
161
  self.write_summary_csv(args)
132
162
  self.write_detail_csv(args)
133
163
 
134
- def compare_output(self, full_api_name, data_info):
164
+ def compare_output(self, full_api_name, data_info, is_online=False):
165
+ """Get compare result and write to result and detail csv.
166
+ is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
167
+ """
135
168
  _, api_name, _ = full_api_name.split(Const.SEP)
136
169
  bench_output, device_output = data_info.bench_output, data_info.device_output
137
170
  bench_grad, device_grad = data_info.bench_grad, data_info.device_grad
@@ -160,6 +193,9 @@ class Comparator:
160
193
  fwd_compare_alg_results,
161
194
  bwd_compare_alg_results,
162
195
  data_info.rank)
196
+ if is_online:
197
+ # get run_ut compare detail
198
+ return self._get_run_ut_detail(result_info)
163
199
  self.record_results(result_info)
164
200
  return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
165
201
  or bwd_success_status == CompareConst.SPACE
@@ -261,15 +297,15 @@ class Comparator:
261
297
  abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
262
298
  abs_err = get_abs_err(bench_output, device_output)
263
299
  rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
264
- if api_name in ThousandthStandardApi:
300
+ if api_name in thousandth_standard_api:
265
301
  thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD)
266
302
  compare_column.rel_err_thousandth = thousand_res
267
303
  if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
268
304
  both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
269
- if api_name in BinaryStandardApi:
305
+ if api_name in binary_standard_api:
270
306
  err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output)
271
307
  compare_column.error_rate = err_rate
272
- elif api_name in AbsoluteStandardApi:
308
+ elif api_name in absolute_standard_api:
273
309
  small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute(
274
310
  api_name, str(dtype))
275
311
  rel_err = abs_err / abs_bench_with_eps
@@ -279,7 +315,7 @@ class Comparator:
279
315
  dtype, rtol)
280
316
  compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol)
281
317
  compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol)
282
- elif api_name in ULPStandardApi:
318
+ elif api_name in ulp_standard_api:
283
319
  if bench_output.size == 0:
284
320
  compare_column.max_ulp_error = 0
285
321
  compare_column.mean_ulp_error = 0