mindstudio-probe 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/METADATA +5 -1
  2. mindstudio_probe-1.0.3.dist-info/RECORD +272 -0
  3. msprobe/README.md +78 -23
  4. msprobe/__init__.py +1 -0
  5. msprobe/config/README.md +182 -40
  6. msprobe/config/config.json +22 -0
  7. msprobe/core/__init__.py +0 -0
  8. msprobe/{pytorch → core}/advisor/advisor.py +3 -3
  9. msprobe/{pytorch → core}/advisor/advisor_result.py +2 -2
  10. msprobe/core/common/const.py +82 -5
  11. msprobe/core/common/exceptions.py +30 -18
  12. msprobe/core/common/file_check.py +19 -1
  13. msprobe/core/common/log.py +15 -1
  14. msprobe/core/common/utils.py +130 -30
  15. msprobe/core/common_config.py +32 -19
  16. msprobe/core/compare/acc_compare.py +299 -0
  17. msprobe/core/compare/check.py +95 -0
  18. msprobe/core/compare/compare_cli.py +49 -0
  19. msprobe/core/compare/highlight.py +222 -0
  20. msprobe/core/compare/multiprocessing_compute.py +149 -0
  21. msprobe/{pytorch → core}/compare/npy_compare.py +55 -4
  22. msprobe/core/compare/utils.py +429 -0
  23. msprobe/core/data_dump/data_collector.py +39 -35
  24. msprobe/core/data_dump/data_processor/base.py +85 -37
  25. msprobe/core/data_dump/data_processor/factory.py +5 -7
  26. msprobe/core/data_dump/data_processor/mindspore_processor.py +198 -0
  27. msprobe/core/data_dump/data_processor/pytorch_processor.py +94 -51
  28. msprobe/core/data_dump/json_writer.py +11 -11
  29. msprobe/core/grad_probe/__init__.py +0 -0
  30. msprobe/core/grad_probe/constant.py +71 -0
  31. msprobe/core/grad_probe/grad_compare.py +175 -0
  32. msprobe/core/grad_probe/utils.py +52 -0
  33. msprobe/doc/grad_probe/grad_probe.md +207 -0
  34. msprobe/doc/grad_probe/img/image-1.png +0 -0
  35. msprobe/doc/grad_probe/img/image-2.png +0 -0
  36. msprobe/doc/grad_probe/img/image-3.png +0 -0
  37. msprobe/doc/grad_probe/img/image-4.png +0 -0
  38. msprobe/doc/grad_probe/img/image.png +0 -0
  39. msprobe/mindspore/api_accuracy_checker/__init__.py +0 -0
  40. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +246 -0
  41. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -0
  42. msprobe/mindspore/api_accuracy_checker/api_runner.py +152 -0
  43. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +197 -0
  44. msprobe/mindspore/api_accuracy_checker/compute_element.py +224 -0
  45. msprobe/mindspore/api_accuracy_checker/main.py +16 -0
  46. msprobe/mindspore/api_accuracy_checker/type_mapping.py +114 -0
  47. msprobe/mindspore/api_accuracy_checker/utils.py +63 -0
  48. msprobe/mindspore/cell_processor.py +34 -0
  49. msprobe/mindspore/common/const.py +87 -0
  50. msprobe/mindspore/common/log.py +38 -0
  51. msprobe/mindspore/common/utils.py +57 -0
  52. msprobe/mindspore/compare/distributed_compare.py +75 -0
  53. msprobe/mindspore/compare/ms_compare.py +117 -0
  54. msprobe/mindspore/compare/ms_graph_compare.py +317 -0
  55. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -0
  56. msprobe/mindspore/debugger/debugger_config.py +38 -15
  57. msprobe/mindspore/debugger/precision_debugger.py +79 -4
  58. msprobe/mindspore/doc/compare.md +58 -0
  59. msprobe/mindspore/doc/dump.md +158 -6
  60. msprobe/mindspore/dump/dump_tool_factory.py +19 -22
  61. msprobe/mindspore/dump/hook_cell/api_registry.py +104 -0
  62. msprobe/mindspore/dump/hook_cell/hook_cell.py +53 -0
  63. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +925 -0
  64. msprobe/mindspore/dump/hook_cell/wrap_functional.py +91 -0
  65. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +63 -0
  66. msprobe/mindspore/dump/jit_dump.py +56 -0
  67. msprobe/mindspore/dump/kernel_kbyk_dump.py +65 -0
  68. msprobe/mindspore/free_benchmark/__init__.py +0 -0
  69. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -0
  70. msprobe/mindspore/free_benchmark/common/__init__.py +0 -0
  71. msprobe/mindspore/free_benchmark/common/config.py +12 -0
  72. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -0
  73. msprobe/mindspore/free_benchmark/common/utils.py +71 -0
  74. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -0
  75. msprobe/mindspore/free_benchmark/decorator/__init__.py +0 -0
  76. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +42 -0
  77. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -0
  78. msprobe/mindspore/free_benchmark/handler/__init__.py +0 -0
  79. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -0
  80. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -0
  81. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -0
  82. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -0
  83. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -0
  84. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -0
  85. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -0
  86. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +34 -0
  87. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -0
  88. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +27 -0
  89. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -0
  90. msprobe/mindspore/grad_probe/__init__.py +0 -0
  91. msprobe/mindspore/grad_probe/global_context.py +91 -0
  92. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -0
  93. msprobe/mindspore/grad_probe/grad_monitor.py +27 -0
  94. msprobe/mindspore/grad_probe/grad_stat_csv.py +132 -0
  95. msprobe/mindspore/grad_probe/hook.py +92 -0
  96. msprobe/mindspore/grad_probe/utils.py +29 -0
  97. msprobe/mindspore/ms_config.py +63 -15
  98. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +17 -15
  99. msprobe/mindspore/runtime.py +4 -0
  100. msprobe/mindspore/service.py +354 -0
  101. msprobe/mindspore/task_handler_factory.py +7 -4
  102. msprobe/msprobe.py +66 -26
  103. msprobe/pytorch/__init__.py +1 -1
  104. msprobe/pytorch/api_accuracy_checker/common/config.py +21 -16
  105. msprobe/pytorch/api_accuracy_checker/common/utils.py +1 -60
  106. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +2 -5
  107. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +46 -10
  108. msprobe/pytorch/api_accuracy_checker/compare/compare.py +84 -48
  109. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +8 -12
  110. msprobe/pytorch/api_accuracy_checker/config.yaml +7 -1
  111. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +15 -11
  112. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +11 -15
  113. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +16 -9
  114. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +193 -105
  115. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +68 -1
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +202 -0
  118. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +324 -0
  119. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -0
  120. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +218 -0
  121. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -0
  122. msprobe/pytorch/bench_functions/__init__.py +15 -0
  123. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -0
  124. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -0
  125. msprobe/pytorch/bench_functions/fast_gelu.py +55 -0
  126. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -0
  127. msprobe/pytorch/bench_functions/linear.py +12 -0
  128. msprobe/pytorch/bench_functions/matmul_backward.py +48 -0
  129. msprobe/pytorch/bench_functions/npu_fusion_attention.py +421 -0
  130. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  131. msprobe/pytorch/bench_functions/rotary_mul.py +52 -0
  132. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -0
  133. msprobe/pytorch/bench_functions/swiglu.py +55 -0
  134. msprobe/pytorch/common/parse_json.py +3 -1
  135. msprobe/pytorch/common/utils.py +83 -7
  136. msprobe/pytorch/compare/distributed_compare.py +19 -64
  137. msprobe/pytorch/compare/match.py +3 -6
  138. msprobe/pytorch/compare/pt_compare.py +40 -0
  139. msprobe/pytorch/debugger/debugger_config.py +11 -2
  140. msprobe/pytorch/debugger/precision_debugger.py +34 -4
  141. msprobe/pytorch/doc/api_accuracy_checker.md +57 -13
  142. msprobe/pytorch/doc/api_accuracy_checker_online.md +187 -0
  143. msprobe/pytorch/doc/dump.md +73 -20
  144. msprobe/pytorch/doc/ptdbg_ascend_compare.md +75 -11
  145. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +3 -3
  146. msprobe/pytorch/doc/run_overflow_check.md +1 -1
  147. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +151 -0
  148. msprobe/pytorch/free_benchmark/common/constant.py +3 -0
  149. msprobe/pytorch/free_benchmark/common/utils.py +4 -0
  150. msprobe/pytorch/free_benchmark/compare/grad_saver.py +22 -26
  151. msprobe/pytorch/free_benchmark/main.py +7 -4
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +1 -1
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +1 -1
  154. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  155. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +3 -3
  156. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +1 -1
  157. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +1 -1
  158. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +43 -29
  159. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +0 -1
  160. msprobe/pytorch/function_factory.py +75 -0
  161. msprobe/pytorch/functional/dump_module.py +4 -4
  162. msprobe/pytorch/grad_probe/__init__.py +0 -0
  163. msprobe/pytorch/grad_probe/grad_monitor.py +90 -0
  164. msprobe/pytorch/grad_probe/grad_stat_csv.py +129 -0
  165. msprobe/pytorch/hook_module/hook_module.py +14 -3
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  167. msprobe/pytorch/hook_module/utils.py +9 -9
  168. msprobe/pytorch/hook_module/wrap_aten.py +20 -10
  169. msprobe/pytorch/hook_module/wrap_distributed.py +10 -7
  170. msprobe/pytorch/hook_module/wrap_functional.py +4 -7
  171. msprobe/pytorch/hook_module/wrap_npu_custom.py +21 -10
  172. msprobe/pytorch/hook_module/wrap_tensor.py +5 -6
  173. msprobe/pytorch/hook_module/wrap_torch.py +5 -7
  174. msprobe/pytorch/hook_module/wrap_vf.py +6 -8
  175. msprobe/pytorch/module_processer.py +53 -13
  176. msprobe/pytorch/online_dispatch/compare.py +4 -4
  177. msprobe/pytorch/online_dispatch/dispatch.py +39 -41
  178. msprobe/pytorch/online_dispatch/dump_compare.py +17 -47
  179. msprobe/pytorch/online_dispatch/single_compare.py +5 -5
  180. msprobe/pytorch/online_dispatch/utils.py +2 -43
  181. msprobe/pytorch/parse_tool/lib/compare.py +31 -19
  182. msprobe/pytorch/parse_tool/lib/config.py +2 -1
  183. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -4
  184. msprobe/pytorch/parse_tool/lib/utils.py +34 -80
  185. msprobe/pytorch/parse_tool/lib/visualization.py +4 -3
  186. msprobe/pytorch/pt_config.py +100 -6
  187. msprobe/pytorch/service.py +104 -19
  188. mindstudio_probe-1.0.1.dist-info/RECORD +0 -228
  189. msprobe/mindspore/dump/api_kbk_dump.py +0 -55
  190. msprobe/pytorch/compare/acc_compare.py +0 -1024
  191. msprobe/pytorch/compare/highlight.py +0 -100
  192. msprobe/test/core_ut/common/test_utils.py +0 -345
  193. msprobe/test/core_ut/data_dump/test_data_collector.py +0 -47
  194. msprobe/test/core_ut/data_dump/test_json_writer.py +0 -183
  195. msprobe/test/core_ut/data_dump/test_scope.py +0 -151
  196. msprobe/test/core_ut/test_common_config.py +0 -152
  197. msprobe/test/core_ut/test_file_check.py +0 -218
  198. msprobe/test/core_ut/test_log.py +0 -109
  199. msprobe/test/mindspore_ut/test_api_kbk_dump.py +0 -51
  200. msprobe/test/mindspore_ut/test_debugger_config.py +0 -42
  201. msprobe/test/mindspore_ut/test_dump_tool_factory.py +0 -51
  202. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +0 -66
  203. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +0 -63
  204. msprobe/test/mindspore_ut/test_ms_config.py +0 -69
  205. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +0 -51
  206. msprobe/test/mindspore_ut/test_precision_debugger.py +0 -56
  207. msprobe/test/mindspore_ut/test_task_handler_factory.py +0 -58
  208. msprobe/test/pytorch_ut/advisor/test_advisor.py +0 -83
  209. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +0 -108
  210. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +0 -39
  211. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +0 -112
  212. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +0 -77
  213. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +0 -125
  214. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +0 -10
  215. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +0 -43
  216. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +0 -179
  217. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +0 -63
  218. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +0 -99
  219. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +0 -115
  220. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +0 -72
  221. msprobe/test/pytorch_ut/compare/test_acc_compare.py +0 -17
  222. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +0 -105
  223. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +0 -121
  224. msprobe/test/pytorch_ut/free_benchmark/test_main.py +0 -101
  225. msprobe/test/pytorch_ut/functional/test_dump_module.py +0 -15
  226. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +0 -130
  227. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +0 -42
  228. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +0 -65
  229. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +0 -35
  230. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +0 -20
  231. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +0 -35
  232. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +0 -43
  233. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +0 -11
  234. msprobe/test/pytorch_ut/test_pt_config.py +0 -69
  235. msprobe/test/pytorch_ut/test_service.py +0 -59
  236. msprobe/test/resources/advisor.txt +0 -3
  237. msprobe/test/resources/compare_result_20230703104808.csv +0 -9
  238. msprobe/test/resources/compare_result_without_accuracy.csv +0 -9
  239. msprobe/test/resources/config.yaml +0 -3
  240. msprobe/test/resources/npu_test.pkl +0 -8
  241. msprobe/test/run_test.sh +0 -30
  242. msprobe/test/run_ut.py +0 -58
  243. msprobe/test/test_module_processer.py +0 -64
  244. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/LICENSE +0 -0
  245. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/WHEEL +0 -0
  246. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/entry_points.txt +0 -0
  247. {mindstudio_probe-1.0.1.dist-info → mindstudio_probe-1.0.3.dist-info}/top_level.txt +0 -0
  248. msprobe/{pytorch → core}/advisor/advisor_const.py +0 -0
  249. msprobe/pytorch/doc/{atat → msprobe}精度工具数据dump标准性能基线报告.md +0 -0
@@ -1,15 +1,18 @@
1
1
  from functools import wraps
2
+
2
3
  import torch
3
4
  from torch.utils.hooks import BackwardHook
5
+
4
6
  from msprobe.core.common.const import Const
5
7
  from msprobe.core.data_dump.scope import ModuleRangeScope
8
+ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
6
9
 
7
10
 
8
11
  class ModuleProcesser:
12
+ module_count = {}
9
13
  module_stack = []
10
14
  api_parent_node = ""
11
15
  module_node = {}
12
- current_module_name = ""
13
16
 
14
17
  def __init__(self, scope):
15
18
  if isinstance(scope, ModuleRangeScope):
@@ -19,15 +22,22 @@ class ModuleProcesser:
19
22
  BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
20
23
  BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
21
24
  BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
22
- self.module_count = {}
23
25
 
24
26
  @staticmethod
25
27
  def filter_tensor_and_tuple(func):
26
28
  @wraps(func)
27
29
  def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
28
- # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入
30
+ # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook
29
31
  # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1]
30
32
  if not isinstance(args[1], (torch.Tensor, tuple)):
33
+ for item_str in dir(args[1]):
34
+ item = getattr(args[1], item_str)
35
+ # 处理tensor或者只包含tensor的元组
36
+ if isinstance(item, torch.Tensor) or \
37
+ (isinstance(item, tuple) and all(isinstance(x, torch.Tensor) for x in item)):
38
+ args_new = (args[0], item)
39
+ result = func(*args_new, **kwargs)
40
+ setattr(args[1], item_str, result)
31
41
  return args[1]
32
42
  return func(*args, **kwargs)
33
43
 
@@ -55,11 +65,26 @@ class ModuleProcesser:
55
65
  else:
56
66
  return result
57
67
 
68
+ @staticmethod
69
+ def module_count_func(module_name):
70
+ if module_name not in ModuleProcesser.module_count:
71
+ ModuleProcesser.module_count[module_name] = 0
72
+ else:
73
+ ModuleProcesser.module_count[module_name] += 1
74
+ return ModuleProcesser.module_count[module_name]
75
+
76
+ @classmethod
77
+ def reset_module_stats(cls):
78
+ cls.module_count = {}
79
+ cls.module_stack = []
80
+ cls.api_parent_node = ""
81
+ cls.module_node = {}
82
+
58
83
  def node_hook(self, name_prefix, start_or_stop, **kwargs):
59
84
 
60
85
  def pre_hook(module, input, output=None):
61
86
  try:
62
- index = self.module_count_func(name_prefix)
87
+ index = ModuleProcesser.module_count_func(name_prefix)
63
88
  except IndexError as e:
64
89
  index = None
65
90
  pass
@@ -85,14 +110,29 @@ class ModuleProcesser:
85
110
  if self.scope:
86
111
  self.scope.end_module(module.mindstudio_reserved_name)
87
112
 
88
- if Const.START in start_or_stop:
89
- return pre_hook
90
- else:
91
- return end_hook
113
+ def backward_hook(module, input, output=None):
114
+ try:
115
+ index = ModuleProcesser.module_count_func(name_prefix)
116
+ except IndexError as e:
117
+ index = None
118
+ pass
119
+ module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
120
+ forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD)
121
+ ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace(
122
+ Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None
123
+ ModuleProcesser.api_parent_node = None
124
+ if self.scope:
125
+ self.scope.begin_module(full_name)
92
126
 
93
- def module_count_func(self, module_name):
94
- if module_name not in self.module_count:
95
- self.module_count[module_name] = 0
127
+ if torch_version_above_or_equal_2:
128
+ if Const.START in start_or_stop:
129
+ return pre_hook
130
+ else:
131
+ return end_hook
96
132
  else:
97
- self.module_count[module_name] += 1
98
- return self.module_count[module_name]
133
+ if Const.FORWARD in name_prefix and Const.START in start_or_stop:
134
+ return pre_hook
135
+ elif Const.BACKWARD in name_prefix:
136
+ return backward_hook
137
+ else:
138
+ return end_hook
@@ -6,10 +6,9 @@ import json
6
6
  from collections import namedtuple
7
7
  from rich.table import Table
8
8
  from rich.console import Console
9
+ from msprobe.core.common.const import CompareConst, FileCheckConst
10
+ from msprobe.core.common.file_check import FileOpen, change_mode
9
11
  from .single_compare import single_benchmark_compare_wrap
10
- from .utils import DispatchException
11
- from msprobe.core.common.const import CompareConst
12
- from msprobe.core.common.file_check import FileOpen
13
12
  from msprobe.pytorch.common.log import logger
14
13
  from msprobe.core.common.utils import CompareException
15
14
 
@@ -42,6 +41,7 @@ def write_csv(data, filepath):
42
41
  with FileOpen(filepath, 'a', encoding='utf-8-sig') as f:
43
42
  writer = csv.writer(f)
44
43
  writer.writerows(data)
44
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
45
45
 
46
46
 
47
47
  class Saver:
@@ -228,7 +228,7 @@ class Comparator:
228
228
  else:
229
229
  is_bwd_success, bwd_compare_alg_results = True, None
230
230
  if is_bwd_success and bwd_compare_alg_results is None:
231
- self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results,
231
+ self.saver.record_results(ResultInfo(api_name, is_fwd_success, CompareConst.NAN, fwd_compare_alg_results,
232
232
  bwd_compare_alg_results))
233
233
  else:
234
234
  self.saver.record_results(ResultInfo(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
@@ -4,7 +4,6 @@ import json
4
4
  from pathlib import Path
5
5
  from multiprocessing import Manager, Pool
6
6
 
7
- import yaml
8
7
  import torch
9
8
 
10
9
  from torch.utils._python_dispatch import TorchDispatchMode
@@ -16,14 +15,14 @@ except ImportError:
16
15
  else:
17
16
  is_npu = True
18
17
 
18
+ from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create, load_yaml
19
+ from msprobe.core.common.const import Const, CompareConst
20
+ from msprobe.pytorch.common.log import logger
19
21
  from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
20
22
  DispatchRunParam, DisPatchDataInfo
21
- from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
22
- DispatchException
23
+ from .utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, COMPARE_LOGO
23
24
  from .compare import Comparator
24
- from msprobe.core.common.file_check import FileOpen
25
- from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
26
- from msprobe.core.common.const import Const, CompareConst
25
+
27
26
 
28
27
  current_time = time.strftime("%Y%m%d%H%M%S")
29
28
  RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
@@ -33,12 +32,12 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
33
32
  class PtdbgDispatch(TorchDispatchMode):
34
33
  def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
35
34
  super(PtdbgDispatch, self).__init__()
36
- logger_logo()
35
+ logger.info(COMPARE_LOGO)
37
36
  if not is_npu:
38
- logger_error("Please confirm you run environment installed torch_npu!")
37
+ logger.error("Please confirm you run environment installed torch_npu!")
39
38
  return
40
39
  if dump_path is None:
41
- logger_error("Please set dump_path when dump_mode is config!")
40
+ logger.error("Please set dump_path when dump_mode is config!")
42
41
  check_file_or_directory_path(dump_path, True)
43
42
 
44
43
  self.device_id = torch_npu._C._npu_getDevice()
@@ -49,7 +48,7 @@ class PtdbgDispatch(TorchDispatchMode):
49
48
  self.single_api_index_dict = {}
50
49
  self.device_dump_path_cpu = None
51
50
  self.device_dump_path_npu = None
52
- self.all_summery = []
51
+ self.all_summary = []
53
52
  self.call_stack_list = []
54
53
  self.process_num = process_num
55
54
  self.filter_dump_api()
@@ -70,13 +69,13 @@ class PtdbgDispatch(TorchDispatchMode):
70
69
  self.aten_ops_blacklist = []
71
70
  self.npu_adjust_autogard = []
72
71
  yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
73
- self.load_yaml_file(yaml_path)
72
+ self.get_ops(yaml_path)
74
73
 
75
74
  self.lock = None
76
75
  if process_num > 0:
77
76
  self.pool = Pool(process_num)
78
77
  if debug:
79
- logger_debug(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
78
+ logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
80
79
  f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
81
80
  f'process[{process_num}]')
82
81
 
@@ -85,17 +84,17 @@ class PtdbgDispatch(TorchDispatchMode):
85
84
 
86
85
  if not is_npu:
87
86
  return
88
- logger_debug(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
87
+ logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')
89
88
 
90
89
  if self.process_num > 0:
91
90
  self.pool.close()
92
91
  self.pool.join()
93
- summery_path = os.path.join(self.root_cpu_path, f'summary.json')
94
- if not os.path.exists(summery_path):
95
- logger_error("Please check train log, An exception may have occurred!")
92
+ summary_path = os.path.join(self.root_cpu_path, f'summary.json')
93
+ if not os.path.exists(summary_path):
94
+ logger.error("Please check train log, An exception may have occurred!")
96
95
  return
97
- check_file_or_directory_path(summery_path, False)
98
- fp_handle = open(summery_path, "r")
96
+ check_file_or_directory_path(summary_path, False)
97
+ fp_handle = open(summary_path, "r")
99
98
  while True:
100
99
  json_line_data = fp_handle.readline()
101
100
  if json_line_data == '\n':
@@ -103,7 +102,7 @@ class PtdbgDispatch(TorchDispatchMode):
103
102
  if len(json_line_data) == 0:
104
103
  break
105
104
  msg = json.loads(json_line_data)
106
- self.all_summery[msg[0]] = msg[1]
105
+ self.all_summary[msg[0]] = msg[1]
107
106
  fp_handle.close()
108
107
 
109
108
  if self.debug_flag:
@@ -111,20 +110,20 @@ class PtdbgDispatch(TorchDispatchMode):
111
110
  output_num = 0
112
111
  total_num = 0
113
112
 
114
- for list_data in self.all_summery:
113
+ for list_data in self.all_summary:
115
114
  for data in list_data:
116
- logger_debug(f'summery: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
115
+ logger.info(f'summary: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
117
116
  if "_input" in data[CompareConst.NPU_NAME]:
118
117
  input_num = input_num + 1
119
118
  if "_output" in data[CompareConst.NPU_NAME]:
120
119
  output_num = output_num + 1
121
120
  total_num = total_num + 1
122
- logger_debug(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
121
+ logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
123
122
  f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
124
123
 
125
124
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
126
125
  if not is_npu:
127
- logger_error("Please confirm you run environment installed torch_npu!")
126
+ logger.error("Please confirm you run environment installed torch_npu!")
128
127
  return func(*args, **kwargs)
129
128
 
130
129
  func_name_split_list = func.__name__.split(".")
@@ -132,7 +131,7 @@ class PtdbgDispatch(TorchDispatchMode):
132
131
  try:
133
132
  aten_api_overload_name = func_name_split_list[1]
134
133
  except IndexError:
135
- logger_error(f"Please check the func name {func.__name__}!")
134
+ logger.error(f"Please check the func name {func.__name__}!")
136
135
  return func(*args, **kwargs)
137
136
 
138
137
  self.enable_autogard(aten_api)
@@ -151,7 +150,7 @@ class PtdbgDispatch(TorchDispatchMode):
151
150
  run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
152
151
 
153
152
  if self.debug_flag:
154
- logger_debug(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
153
+ logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
155
154
  f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
156
155
  f'Count[{self.api_index}], Sys[{get_sys_info()}]')
157
156
 
@@ -175,21 +174,21 @@ class PtdbgDispatch(TorchDispatchMode):
175
174
  cpu_out = cpu_out.float()
176
175
 
177
176
  if self.process_num == 0:
178
- self.all_summery.append([])
179
- data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, func, npu_out_cpu, cpu_out, self.lock)
177
+ self.all_summary.append([])
178
+ data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out, self.lock)
180
179
  dispatch_workflow(run_param, data_info)
181
180
  else:
182
181
  self.lock.acquire()
183
- self.all_summery.append([])
182
+ self.all_summary.append([])
184
183
  self.lock.release()
185
184
  run_param.process_flag = True
186
185
  if self.check_fun(func, run_param):
187
- data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, None, npu_out_cpu, cpu_out,
186
+ data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
188
187
  self.lock)
189
188
  self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
190
189
  error_callback=error_call)
191
190
  else:
192
- logger_error("can not get correct function please set process_num=0")
191
+ logger.error("can not get correct function please set process_num=0")
193
192
  return npu_out
194
193
 
195
194
  @staticmethod
@@ -208,17 +207,16 @@ class PtdbgDispatch(TorchDispatchMode):
208
207
  time.sleep(1)
209
208
  time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
210
209
  if tag is None or not isinstance(tag, str):
211
- logger_warn('There is not tag or the type of tag is not string.')
210
+ logger.warning('There is not tag or the type of tag is not string.')
212
211
  dir_name = f'msprobe_rank{self.device_id}_{time_now}'
213
212
  else:
214
213
  dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
215
214
  return dir_name
216
215
 
217
- def load_yaml_file(self, file_path):
218
- with FileOpen(file_path, 'r') as f:
219
- yaml_file = yaml.safe_load(f)
220
- self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
221
- self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
216
+ def get_ops(self, file_path):
217
+ yaml_file = load_yaml(file_path)
218
+ self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
219
+ self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
222
220
 
223
221
  def filter_dump_api(self):
224
222
  if self.dump_mode != Const.LIST or not self.dump_api_list:
@@ -230,7 +228,7 @@ class PtdbgDispatch(TorchDispatchMode):
230
228
  if aten_api in aten_api_list:
231
229
  dump_api_list.append(aten_api)
232
230
  else:
233
- logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
231
+ logger.warning(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
234
232
  self.dump_api_list = dump_api_list
235
233
 
236
234
  def get_run_param(self, aten_api, func_name, aten_api_overload_name):
@@ -257,16 +255,16 @@ class PtdbgDispatch(TorchDispatchMode):
257
255
 
258
256
  def check_param(self):
259
257
  if self.dump_mode not in Const.ONLINE_DUMP_MODE:
260
- logger_error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
258
+ logger.error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
261
259
  raise DispatchException(DispatchException.INVALID_PARAMETER)
262
260
  if not isinstance(self.dump_api_list, list):
263
- logger_error('The type of parameter "api_list" can only be list.')
261
+ logger.error('The type of parameter "api_list" can only be list.')
264
262
  raise DispatchException(DispatchException.INVALID_PARAMETER)
265
263
  if not isinstance(self.debug_flag, bool):
266
- logger_error('The type of parameter "debug" can only be bool.')
264
+ logger.error('The type of parameter "debug" can only be bool.')
267
265
  raise DispatchException(DispatchException.INVALID_PARAMETER)
268
266
  if not isinstance(self.process_num, int) or self.process_num < 0:
269
- logger_error('The type of parameter "process_num" can only be int and it should not be less than 0.')
267
+ logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.')
270
268
  raise DispatchException(DispatchException.INVALID_PARAMETER)
271
269
 
272
270
  def enable_autogard(self, aten_api):
@@ -5,11 +5,10 @@ from datetime import datetime, timezone
5
5
 
6
6
  import pandas as pd
7
7
  import torch
8
- from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \
9
- COLOR_RESET, CSV_COLUMN_NAME
10
- from msprobe.core.common.file_check import FileOpen, change_mode
11
- from msprobe.core.common.const import CompareConst, FileCheckConst, Const
12
8
  from msprobe.pytorch.common.log import logger
9
+ from msprobe.core.common.file_check import FileOpen
10
+ from .utils import np_save_data
11
+
13
12
 
14
13
  class DispatchRunParam:
15
14
  def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator):
@@ -32,10 +31,10 @@ class DispatchRunParam:
32
31
 
33
32
 
34
33
  class DisPatchDataInfo:
35
- def __init__(self, cpu_args, cpu_kwargs, all_summery, func, npu_out_cpu, cpu_out, lock):
34
+ def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock):
36
35
  self.cpu_args = cpu_args
37
36
  self.cpu_kwargs = cpu_kwargs
38
- self.all_summery = all_summery
37
+ self.all_summary = all_summary
39
38
  self.func = func
40
39
  self.npu_out_cpu = npu_out_cpu
41
40
  self.cpu_out = cpu_out
@@ -57,7 +56,7 @@ class TimeStatistics:
57
56
  def __enter__(self):
58
57
  if self.debug:
59
58
  self.time = datetime.now(tz=timezone.utc)
60
- logger_debug(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
59
+ logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
61
60
  f'Id[{self.index}]')
62
61
 
63
62
  def __exit__(self, exc_type, exc_val, exc_tb):
@@ -68,9 +67,9 @@ class TimeStatistics:
68
67
  hot_time_cost = "Hotspot " + time_cost
69
68
 
70
69
  if cost_time.total_seconds() > self.timeout:
71
- logger_debug(hot_time_cost)
70
+ logger.info(hot_time_cost)
72
71
  else:
73
- logger_debug(time_cost)
72
+ logger.info(time_cost)
74
73
 
75
74
 
76
75
  def support_basic_type(data):
@@ -87,24 +86,24 @@ def dump_data(data, prefix, dump_path):
87
86
  elif support_basic_type(data):
88
87
  if isinstance(data, torch.Tensor) and data.is_meta:
89
88
  return
90
- # dump data may greater than summery_list collect
89
+ # dump data may greater than summary_list collect
91
90
  np_save_data(data, prefix, dump_path)
92
91
 
93
92
 
94
- def save_temp_summery(api_index, single_api_summery, path, lock):
95
- summery_path = os.path.join(path, f'summery.json')
93
+ def save_temp_summary(api_index, single_api_summary, path, lock):
94
+ summary_path = os.path.join(path, f'summary.json')
96
95
  lock.acquire()
97
- with FileOpen(summery_path, "a") as f:
98
- json.dump([api_index, single_api_summery], f)
96
+ with FileOpen(summary_path, "a") as f:
97
+ json.dump([api_index, single_api_summary], f)
99
98
  f.write('\n')
100
99
  lock.release()
101
100
 
102
101
 
103
102
  def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
104
103
  cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs
105
- all_summery, func = data_info.all_summery, data_info.func
104
+ all_summary, func = data_info.all_summary, data_info.func
106
105
  npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock
107
- single_api_summery = []
106
+ single_api_summary = []
108
107
 
109
108
  prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input'
110
109
  prefix_output = f'{run_param.aten_api}_{run_param.single_api_index}_output'
@@ -127,9 +126,9 @@ def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
127
126
  dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path)
128
127
 
129
128
  if run_param.process_num == 0:
130
- all_summery[run_param.api_index - 1] = copy.deepcopy(single_api_summery)
129
+ all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary)
131
130
  else:
132
- save_temp_summery(run_param.api_index - 1, single_api_summery, run_param.root_cpu_path, lock)
131
+ save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock)
133
132
 
134
133
 
135
134
  def get_torch_func(run_param):
@@ -155,32 +154,3 @@ def dispatch_multiprocess(run_param, dispatch_data_info):
155
154
  def error_call(err):
156
155
  logger.error(f'multiprocess {err}')
157
156
 
158
-
159
- def save_csv(all_summery, call_stack_list, csv_path):
160
- df = pd.DataFrame(columns=CSV_COLUMN_NAME)
161
-
162
- for index, list_data in enumerate(all_summery):
163
- for data in list_data:
164
- csv_row_data = {CompareConst.NPU_NAME: data[CompareConst.NPU_NAME],
165
- CompareConst.BENCH_NAME: data[CompareConst.BENCH_NAME],
166
- CompareConst.NPU_DTYPE: data[CompareConst.NPU_DTYPE],
167
- CompareConst.BENCH_DTYPE: data[CompareConst.BENCH_DTYPE],
168
- CompareConst.NPU_SHAPE: data[CompareConst.NPU_SHAPE],
169
- CompareConst.BENCH_SHAPE: data[CompareConst.BENCH_SHAPE],
170
- CompareConst.NPU_MAX: data[CompareConst.NPU_MAX],
171
- CompareConst.NPU_MIN: data[CompareConst.NPU_MIN],
172
- CompareConst.NPU_MEAN: data[CompareConst.NPU_MEAN],
173
- CompareConst.BENCH_MAX: data[CompareConst.BENCH_MAX],
174
- CompareConst.BENCH_MIN: data[CompareConst.BENCH_MIN],
175
- CompareConst.BENCH_MEAN: data[CompareConst.BENCH_MEAN],
176
- CompareConst.COSINE: data[CompareConst.COSINE],
177
- CompareConst.MAX_ABS_ERR: data[CompareConst.MAX_ABS_ERR],
178
- CompareConst.MAX_RELATIVE_ERR: data[CompareConst.MAX_RELATIVE_ERR],
179
- CompareConst.ACCURACY: data[CompareConst.ACCURACY],
180
- CompareConst.STACK: call_stack_list[index],
181
- CompareConst.ERROR_MESSAGE: data[CompareConst.ERROR_MESSAGE]}
182
- row_df = pd.DataFrame.from_dict(csv_row_data, orient='index').T
183
- df = pd.concat([df, row_df])
184
-
185
- df.to_csv(csv_path, index=False)
186
- change_mode(csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -3,15 +3,15 @@ from functools import wraps
3
3
  import torch
4
4
  from prettytable import PrettyTable
5
5
  from collections import namedtuple
6
- from .utils import logger_user, logger_debug
6
+ from msprobe.pytorch.common.log import logger
7
7
 
8
8
  def func_log_wrapper():
9
9
  def _out_wrapper(func):
10
10
  @wraps(func)
11
11
  def _in_wrapper(*kargs, **kwargs):
12
- logger_debug("start to run: {}".format(func.__name__))
12
+ logger.info(f"start to run: {func.__name__}")
13
13
  x = func(*kargs, **kwargs)
14
- logger_debug("end to run: {}".format(func.__name__))
14
+ logger.info(f"end to run: {func.__name__}")
15
15
  return x
16
16
 
17
17
  return _in_wrapper
@@ -165,7 +165,7 @@ class SingleBenchmarkAccuracyCompare:
165
165
  def compute_binary_diff(cls, npu_out, bench_out):
166
166
  result = torch.equal(npu_out, bench_out)
167
167
  if result:
168
- logger_user("二进制精度比对通过, 无需单标杆比对法验证")
168
+ logger.info("二进制精度比对通过, 无需单标杆比对法验证")
169
169
  return SingleBenchmarkAccuracyResult(result=result, max_abs_diff=0, max_rel_diff=0, error_balance=0)
170
170
 
171
171
  @classmethod
@@ -301,7 +301,7 @@ class SingleBenchSummary:
301
301
  table.add_row(["max_rel_diff", self.max_rel_diff, self.error_thd])
302
302
  table.add_row(["max_rel_idx", self.max_rel_idx, "-"])
303
303
 
304
- logger_user(table)
304
+ logger.info(table)
305
305
 
306
306
  def to_column_value(self):
307
307
  return [self.bench_dtype, self.npu_dtype, self.shape, self.error_balance,
@@ -1,6 +1,5 @@
1
1
  import os
2
2
  import inspect
3
- import logging
4
3
  import psutil
5
4
  import torch
6
5
  import numpy as np
@@ -14,6 +13,7 @@ else:
14
13
 
15
14
  from msprobe.core.common.const import CompareConst, FileCheckConst
16
15
  from msprobe.core.common.file_check import change_mode
16
+ from msprobe.core.common.log import logger
17
17
 
18
18
  cpu_device = torch._C.device("cpu")
19
19
  COLOR_RED = '\033[31m'
@@ -77,7 +77,7 @@ def np_save_data(data, file_name, data_path):
77
77
  np.save(dump_path, data)
78
78
  change_mode(dump_path, FileCheckConst.DATA_FILE_AUTHORITY)
79
79
  except Exception as e:
80
- logger_error("save numpy failed, error: {}".format(e))
80
+ logger.error("save numpy failed, error: {}".format(e))
81
81
  finally:
82
82
  pass
83
83
 
@@ -124,47 +124,6 @@ def data_to_cpu(data, deep, data_cpu):
124
124
  return data
125
125
 
126
126
 
127
- def get_mp_logger():
128
- logger = logging.getLogger(__name__)
129
- if not logger.handlers:
130
- logger.setLevel(logging.INFO)
131
- handler = logging.StreamHandler()
132
- formatter = logging.Formatter('%(asctime)s %(message)s')
133
- logger.propagate = True
134
- handler.setFormatter(formatter)
135
- logger.addHandler(handler)
136
- return logger.info
137
-
138
-
139
- def logger_debug(mesg):
140
- logger = get_mp_logger()
141
- logger(f'DEBUG ' + mesg)
142
-
143
-
144
- def logger_info(mesg):
145
- logger = get_mp_logger()
146
- logger(f'INFO ' + mesg)
147
-
148
-
149
- def logger_warn(mesg):
150
- logger = get_mp_logger()
151
- logger(f'{COLOR_YELLOW}WARNING {mesg} {COLOR_RESET}')
152
-
153
-
154
- def logger_error(mesg):
155
- logger = get_mp_logger()
156
- logger(f'{COLOR_RED}ERROR {mesg} {COLOR_RESET}')
157
-
158
-
159
- def logger_user(mesg):
160
- logger = get_mp_logger()
161
- logger(mesg)
162
-
163
-
164
- def logger_logo():
165
- logger_user(f'{COLOR_CYAN}{COMPARE_LOGO} {COLOR_RESET}')
166
-
167
-
168
127
  def get_sys_info():
169
128
  mem = psutil.virtual_memory()
170
129
  cpu_percent = psutil.cpu_percent(interval=1)