mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,98 @@
1
+ from functools import wraps
2
+ import torch
3
+ from torch.utils.hooks import BackwardHook
4
+ from msprobe.core.common.const import Const
5
+ from msprobe.core.data_dump.scope import ModuleRangeScope
6
+
7
+
8
class ModuleProcesser:
    """Track which module is currently executing so dumped data can be scoped.

    ``node_hook`` builds a pair of hooks (start / stop) that maintain a
    class-wide stack of running module names.  Constructing an instance also
    patches ``torch.utils.hooks.BackwardHook`` so that the values returned by
    its ``setup_input_hook`` / ``setup_output_hook`` are cloned.
    """

    # Full names of the modules currently running, outermost first.
    # Pushed by the start hook and popped by the stop hook.
    module_stack = []
    # Full name of the innermost running module (top of ``module_stack``);
    # set to None by the stop hook once the stack is empty.
    api_parent_node = ""
    # Maps a module's full name to the full name of the module that was on
    # top of the stack when it started (None when the stack was empty).
    module_node = {}
    # Not referenced anywhere inside this class.
    current_module_name = ""

    # Attribute set on the patched BackwardHook methods so that creating
    # several ModuleProcesser instances does not stack wrapper upon wrapper.
    _PATCH_MARKER = "_msprobe_module_processer_patched"

    def __init__(self, scope):
        """Store ``scope`` (only if it is a ModuleRangeScope) and patch BackwardHook.

        Args:
            scope: a scope object; anything that is not a ``ModuleRangeScope``
                is ignored and ``self.scope`` becomes None.
        """
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None
        ModuleProcesser._patch_backward_hook()
        # Per-instance counter: name prefix -> index of its latest call.
        self.module_count = {}

    @staticmethod
    def _patch_backward_hook():
        """Wrap the BackwardHook setup methods exactly once.

        Previously every instantiation wrapped the already wrapped methods
        again, so each extra instance added one more clone per call.  The
        marker attribute makes the patch idempotent while still re-applying
        it if someone restored the unpatched methods.
        """
        marker = ModuleProcesser._PATCH_MARKER

        input_hook = BackwardHook.setup_input_hook
        if not getattr(input_hook, marker, False):
            input_hook = ModuleProcesser.clone_return_value(input_hook)
            setattr(input_hook, marker, True)
            BackwardHook.setup_input_hook = input_hook

        output_hook = BackwardHook.setup_output_hook
        if not getattr(output_hook, marker, False):
            output_hook = ModuleProcesser.clone_return_value(output_hook)
            output_hook = ModuleProcesser.filter_tensor_and_tuple(output_hook)
            setattr(output_hook, marker, True)
            BackwardHook.setup_output_hook = output_hook

    @staticmethod
    def filter_tensor_and_tuple(func):
        """Decorate ``func`` so it is bypassed for non-tensor, non-tuple data.

        Returns:
            A wrapper that returns ``args[1]`` unchanged when it is neither a
            ``torch.Tensor`` nor a tuple, and otherwise delegates to ``func``.
        """
        @wraps(func)
        def wrap_by_filter_tensor_and_tuple(*args, **kwargs):
            # Passing non-tensor data to setup_output_hook makes the tool's
            # later dump fail, so such data is returned without being passed in.
            # setup_output_hook is defined as setup_output_hook(self, args),
            # hence the second positional argument, args[1], is inspected.
            if not isinstance(args[1], (torch.Tensor, tuple)):
                return args[1]
            return func(*args, **kwargs)

        return wrap_by_filter_tensor_and_tuple

    @staticmethod
    def clone_return_value(func):
        """Decorate ``func`` so that tensors in its return value are cloned."""
        @wraps(func)
        def clone_return_value_func(*args, **kwargs):
            result = func(*args, **kwargs)
            return ModuleProcesser.clone_if_tensor(result)

        return clone_return_value_func

    @staticmethod
    def clone_if_tensor(result):
        """Recursively clone every tensor found in ``result``.

        Tuples, lists and dicts are rebuilt with their elements processed
        recursively; any other non-tensor object is returned as-is.
        """
        if isinstance(result, torch.Tensor):
            return result.clone()
        if isinstance(result, tuple):
            return tuple(ModuleProcesser.clone_if_tensor(x) for x in result)
        if isinstance(result, list):
            return [ModuleProcesser.clone_if_tensor(x) for x in result]
        if isinstance(result, dict):
            return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()}
        return result

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """Build the start or stop hook for the module named ``name_prefix``.

        Args:
            name_prefix: module name; the per-prefix call index is appended
                to it (joined by ``Const.SEP``) to form the full name.
            start_or_stop: the start hook is returned when it contains
                ``Const.START``, the stop hook otherwise.
            **kwargs: accepted but unused.

        Returns:
            A callable with the signature ``(module, input, output=None)``.
        """

        def pre_hook(module, input, output=None):
            try:
                index = self.module_count_func(name_prefix)
            except IndexError:
                index = None
            full_name = name_prefix + Const.SEP + str(index)
            module.mindstudio_reserved_name = full_name

            # Record the parent before pushing this module onto the stack.
            if self.module_stack:
                ModuleProcesser.module_node[full_name] = self.module_stack[-1]
            else:
                ModuleProcesser.module_node[full_name] = None

            ModuleProcesser.module_stack.append(full_name)
            # The module just pushed is the new top of the stack.
            ModuleProcesser.api_parent_node = full_name
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(module, input, output=None):
            if self.module_stack:
                ModuleProcesser.module_stack.pop()
            if self.module_stack:
                ModuleProcesser.api_parent_node = self.module_stack[-1]
            else:
                ModuleProcesser.api_parent_node = None
            if self.scope:
                self.scope.end_module(module.mindstudio_reserved_name)

        if Const.START in start_or_stop:
            return pre_hook
        return end_hook

    def module_count_func(self, module_name):
        """Return the zero-based call index for ``module_name`` and advance it.

        The first call for a given name returns 0, the next one 1, and so on.
        """
        if module_name not in self.module_count:
            self.module_count[module_name] = 0
        else:
            self.module_count[module_name] += 1
        return self.module_count[module_name]
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2024-2024 Huawei Technologies Co., Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from signal import signal, SIG_DFL

from .dispatch import PtdbgDispatch

try:
    from signal import SIGPIPE
except ImportError:
    # SIGPIPE is only defined on Unix; keep the package importable elsewhere.
    SIGPIPE = None

# Restore the default SIGPIPE disposition when this package is imported.
if SIGPIPE is not None:
    try:
        signal(SIGPIPE, SIG_DFL)
    except ValueError:
        # signal() may only be called from the main thread of the main
        # interpreter; importing from another thread must not fail, so the
        # handler is deliberately left untouched in that case.
        pass


__all__ = ["PtdbgDispatch"]
@@ -0,0 +1,236 @@
1
+ # 进行比对及结果展示
2
+ import os
3
+ import sys
4
+ import csv
5
+ import json
6
+ from collections import namedtuple
7
+ from rich.table import Table
8
+ from rich.console import Console
9
+ from .single_compare import single_benchmark_compare_wrap
10
+ from .utils import DispatchException
11
+ from msprobe.core.common.const import CompareConst
12
+ from msprobe.core.common.file_check import FileOpen
13
+ from msprobe.pytorch.common.log import logger
14
+ from msprobe.core.common.utils import CompareException
15
+
16
# Minimum element count for the dropout zero-count check in
# Comparator._compare_dropout; smaller tensors always pass.
ELEMENT_NUM_THRESHOLD = 100
# Upper bound (exclusive) on the zero-count difference between the two
# tensors, relative to the element count, for a dropout result to pass.
ZERO_NUM_THRESHOLD = 0.1
# Number of decimal places used when float values are written to the detail CSV.
FLOAT_PRECISION = 14

# One comparison outcome for a single API, as handed to Saver.record_results.
ResultInfo = namedtuple(
    'ResultInfo',
    [
        'api_name',
        'is_fwd_success',
        'is_bwd_success',
        'fwd_compare_alg_results',
        'bwd_compare_alg_results',
    ],
)
22
+
23
def get_file_content_bytes(file):
    """Read ``file`` in binary mode through ``FileOpen`` and return its full content."""
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
26
+
27
+
28
def get_json_contents(file_path):
    """Load ``file_path`` as JSON and return it, requiring a top-level object.

    Raises:
        CompareException: with ``INVALID_FILE_ERROR`` when the content cannot
            be parsed as JSON or when the parsed value is not a dict.
    """
    raw_content = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw_content)
    except ValueError as error:
        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error

    if isinstance(parsed, dict):
        return parsed

    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
39
+
40
+
41
def write_csv(data, filepath):
    """Append the rows in ``data`` to the CSV file at ``filepath``.

    The file is opened through ``FileOpen`` in append mode with the
    ``utf-8-sig`` encoding.
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
45
+
46
+
47
class Saver:
    """Write comparison results to the summary/detail CSV files and print statistics."""

    # consts for result csv
    COLUMN_API_NAME = "API name"
    COLUMN_FORWARD_SUCCESS = "Forward Test Success"
    COLUMN_BACKWARD_SUCCESS = "Backward Test Success"
    COLUMN_STACK_INFO = "Traceback callstack info"

    def __init__(self, save_path, detail_save_path, stack_info):
        """
        Args:
            save_path: path of the summary CSV (one row per API).
            detail_save_path: path of the detail CSV (one row per output).
            stack_info: mapping from API name to its call-stack lines, or a
                falsy value when no stack information is available.
        """
        self.save_path = save_path
        self.detail_save_path = detail_save_path
        self.stack_info = stack_info

        self.test_result_cnt = self._empty_counters()

    @staticmethod
    def _empty_counters():
        """Return a fresh statistics dict with every counter at zero."""
        return {
            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
            "total_num": 0, "forward_or_backward_fail_num": 0
        }

    def write_csv_title(self):
        """Append the header rows to the summary and the detail CSV files."""
        summary_header = [self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]
        if self.stack_info:
            # The stack column belongs to the header, not to each data row.
            summary_header.append(self.COLUMN_STACK_INFO)
        write_csv([summary_header], self.save_path)

        detail_test_rows = [[
            "Npu Name", "Bench Dtype", "NPU Dtype", "Shape",
            "error_balance", "max_abs_diff", "max_abs_idx",
            "max_rel_diff", "max_rel_idx", "eb_thd",
            "error_thd", "Status", "Message"
        ]]
        write_csv(detail_test_rows, self.detail_save_path)

    def print_pretest_result(self):
        """Recompute the statistics from the summary CSV and print two tables."""
        self.get_statistics_from_result_csv()
        if self.test_result_cnt.get("total_num") != 0:
            passing_rate = str(self.test_result_cnt.get("success_num") /
                               (self.test_result_cnt.get("total_num") + sys.float_info.epsilon))
        else:
            passing_rate = "0"

        console = Console()
        table_total = Table(
            show_header=True, title="Overall Statistics", show_lines=True, width=75
        )
        table_total.add_column("Result")
        table_total.add_column("Statistics")
        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
        table_total.add_row("Passing Rate", passing_rate)

        table_detail = Table(
            show_header=True, title="Detail Statistics", show_lines=True, width=75
        )
        table_detail.add_column("Result")
        table_detail.add_column("Statistics")
        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
        table_detail.add_row(
            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))

        console.print(table_total)
        console.print(table_detail)

    def get_statistics_from_result_csv(self):
        """Count pass/fail rows of the summary CSV into ``self.test_result_cnt``.

        The first row is treated as the header and skipped.  Rows whose
        forward column is SKIP are not counted at all.

        Raises:
            ValueError: when a row has fewer than three columns or when its
                2nd/3rd column is not one of TRUE, FALSE, SKIP or N/A.
        """
        # The whole file is re-read on every call, so start from zero;
        # otherwise a second call would double every counter.
        self.test_result_cnt.update(self._empty_counters())

        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
        with FileOpen(self.save_path, 'r') as file:
            result_csv_rows = list(csv.reader(file))
        result_csv_name = os.path.basename(self.save_path)
        for item in result_csv_rows[1:]:
            if not isinstance(item, list) or len(item) < 3:
                raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
                raise ValueError(
                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
                    % result_csv_name)
            column1 = item[1].upper()
            column2 = item[2].upper()
            if column1 == CompareConst.SKIP:
                continue
            self.test_result_cnt["total_num"] += 1
            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
                self.test_result_cnt['success_num'] += 1
            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
                self.test_result_cnt['forward_and_backward_fail_num'] += 1
            elif column1 == CompareConst.FALSE:
                self.test_result_cnt['forward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1
            else:
                self.test_result_cnt['backward_fail_num'] += 1
                self.test_result_cnt['forward_or_backward_fail_num'] += 1

    def _build_summary_row(self, test_result):
        """Return the summary CSV row for ``test_result`` as a list.

        Layout: API name, forward status, backward status, then the message
        (``fwd_compare_alg_results``) when either status is "SKIP".  When
        stack information is configured, the message cell is always present
        (empty if there is no message) so the stack text lines up with its
        header column, followed by the newline-joined stack lines.
        """
        name = test_result.api_name
        row = [name, test_result.is_fwd_success, test_result.is_bwd_success]
        if test_result.is_fwd_success == "SKIP" or test_result.is_bwd_success == "SKIP":
            row.append(test_result.fwd_compare_alg_results)
        elif self.stack_info:
            row.append("")
        if self.stack_info:
            # An API without a recorded stack gets an empty cell instead of
            # aborting the whole result recording.
            row.append("\n".join(self.stack_info.get(name, [])))
        return row

    def write_summary_csv(self, test_result):
        """Append one row describing ``test_result`` to the summary CSV.

        The previous implementation indexed into an empty list
        (``test_rows[0]``) and therefore raised IndexError whenever stack
        information was configured.
        """
        write_csv([self._build_summary_row(test_result)], self.save_path)

    def write_detail_csv(self, test_result):
        """Append one row per forward/backward output of ``test_result`` to the detail CSV."""
        def get_rows_from_list(result, name, sub_prefix):
            # Results that are not lists produce no detail rows.
            if not isinstance(result, list):
                return []
            rows = []
            for i, test_subject in enumerate(result):
                subject = sub_prefix + "." + name + ".output." + str(i)
                # Floats are written with a fixed number of decimal places.
                formatted = ["{:.{}f}".format(item, FLOAT_PRECISION) if isinstance(item, float) else item
                             for item in test_subject]
                rows.append([subject] + formatted)
            return rows

        test_rows = []
        subject_prefix = test_result.api_name
        fwd_result = test_result.fwd_compare_alg_results
        bwd_result = test_result.bwd_compare_alg_results

        test_rows.extend(get_rows_from_list(fwd_result, "forward", subject_prefix))
        test_rows.extend(get_rows_from_list(bwd_result, "backward", subject_prefix))

        write_csv(test_rows, self.detail_save_path)

    def record_results(self, result_info):
        """Write ``result_info`` to both the summary and the detail CSV."""
        self.write_summary_csv(result_info)
        self.write_detail_csv(result_info)
176
+
177
+
178
class Comparator:
    """Compare benchmark and NPU outputs per API and record the outcome via ``Saver``."""

    def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None):
        """
        Args:
            result_csv_path: path of the summary CSV.
            details_csv_path: path of the detail CSV.
            is_continue_run_ut: the CSV headers are written only when this is
                truthy and neither CSV file exists yet.
            stack_info_json_path: optional JSON file with stack information;
                loaded with ``get_json_contents`` when given.
        """
        self.save_path = result_csv_path
        self.detail_save_path = details_csv_path
        self.stack_info = get_json_contents(stack_info_json_path) if stack_info_json_path else None
        self.saver = Saver(result_csv_path, details_csv_path, self.stack_info)

        if is_continue_run_ut and not (os.path.exists(self.save_path) or os.path.exists(self.detail_save_path)):
            self.saver.write_csv_title()

    @staticmethod
    def _compare_core_wrapper(bench_out, npu_out):
        """Run ``single_benchmark_compare_wrap`` and normalise its result.

        Returns:
            ``(success, details_list)``.  For a non-list status the status is
            returned as-is with the details wrapped in a one-element list.
            For a list status, success is False as soon as any entry is falsy
            and the details are looked up by position.
        """
        status, details = single_benchmark_compare_wrap(npu_out, bench_out)
        if not isinstance(status, list):
            return status, [details]

        collected_details = []
        all_passed = True
        for position, passed in enumerate(status):
            collected_details.append(details.get(position, 'key does not exist'))
            if not passed:
                all_passed = False
        return all_passed, collected_details

    @staticmethod
    def _compare_dropout(bench_out, npu_out):
        """Compare dropout outputs by their number of zero elements.

        Tensors with fewer than ``ELEMENT_NUM_THRESHOLD`` elements always
        pass.  Otherwise the absolute difference of the zero counts, divided
        by the element count, must stay below ``ZERO_NUM_THRESHOLD``.

        Returns:
            ``(True, 1)`` on success, ``(False, 0)`` on failure.
        """
        tensor_num = bench_out.numel()
        if tensor_num < ELEMENT_NUM_THRESHOLD:
            return True, 1

        zero_count_gap = abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum())
        if zero_count_gap / tensor_num < ZERO_NUM_THRESHOLD:
            return True, 1
        return False, 0

    def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None):
        """Compare forward (and, when given, backward) results and record them.

        APIs whose name contains "dropout" use the zero-count comparison; all
        others use the generic comparison.  When no gradients are compared,
        the backward status is recorded as ``CompareConst.NA``.

        Returns:
            ``(is_fwd_success, is_bwd_success)``; the backward flag is True
            when no backward comparison took place.
        """
        is_dropout = "dropout" in api_name

        if is_dropout:
            forward = self._compare_dropout(bench_out, npu_out)
        else:
            forward = self._compare_core_wrapper(bench_out, npu_out)
        is_fwd_success, fwd_compare_alg_results = forward

        # NOTE(review): this relies on the truthiness of the gradient
        # arguments; presumably they are sequences of tensors - confirm.
        if bench_grad and npu_grad:
            if is_dropout:
                backward = self._compare_dropout(bench_grad[0], npu_grad[0])
            else:
                backward = self._compare_core_wrapper(bench_grad, npu_grad)
            is_bwd_success, bwd_compare_alg_results = backward
        else:
            is_bwd_success, bwd_compare_alg_results = True, None

        if is_bwd_success and bwd_compare_alg_results is None:
            recorded_bwd_status = CompareConst.NA
        else:
            recorded_bwd_status = is_bwd_success
        self.saver.record_results(ResultInfo(api_name, is_fwd_success, recorded_bwd_status,
                                             fwd_compare_alg_results, bwd_compare_alg_results))
        return is_fwd_success, is_bwd_success
@@ -0,0 +1,274 @@
1
+ import os
2
+ import time
3
+ import json
4
+ from pathlib import Path
5
+ from multiprocessing import Manager, Pool
6
+
7
+ import yaml
8
+ import torch
9
+
10
+ from torch.utils._python_dispatch import TorchDispatchMode
11
+
12
# torch_npu is optional at import time; is_npu records whether it is available.
try:
    import torch_npu
    is_npu = True
except ImportError:
    is_npu = False
18
+
19
+ from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
20
+ DispatchRunParam, DisPatchDataInfo
21
+ from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \
22
+ DispatchException
23
+ from .compare import Comparator
24
+ from msprobe.core.common.file_check import FileOpen
25
+ from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create
26
+ from msprobe.core.common.const import Const, CompareConst
27
+
28
# Timestamp taken once at import time so both report files share the same suffix.
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = f"accuracy_checking_result_{current_time}.csv"
DETAILS_FILE_NAME = f"accuracy_checking_details_{current_time}.csv"
31
+
32
+
33
class PtdbgDispatch(TorchDispatchMode):
    """Dispatch mode that runs each intercepted aten op on NPU and again on CPU.

    Every op reaching ``__torch_dispatch__`` (unless its name is listed under
    ``aten_ops_blacklist`` in ``torch_ops_config.yaml``) is executed with the
    original arguments and with CPU copies of them. Both outputs are handed to
    ``dispatch_workflow`` in-process, or to a worker pool running
    ``dispatch_multiprocess`` when ``process_num`` is greater than 0.
    """

    def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0):
        """Validate parameters and prepare output directories, comparator and workers.

        Args:
            dump_mode: must be one of ``Const.ONLINE_DUMP_MODE``.
            api_list: aten api names to dump; only honoured when ``dump_mode``
                is ``Const.LIST``, otherwise replaced by an empty list.
            debug: enables the debug log lines; must be a bool.
            dump_path: directory under which the per-run result directory is
                created.
            tag: optional string embedded in the result directory name.
            process_num: number of worker processes; 0 runs the comparison in
                the calling process.

        Raises:
            DispatchException: when ``check_param`` rejects a parameter.
        """
        super(PtdbgDispatch, self).__init__()
        logger_logo()
        if not is_npu:
            # Without torch_npu nothing else is initialised; the other methods
            # re-check is_npu before touching instance state.
            logger_error("Please confirm you run environment installed torch_npu!")
            return
        if dump_path is None:
            # NOTE(review): the error is only logged here; rejecting the None
            # path is left to check_file_or_directory_path below.
            logger_error("Please set dump_path when dump_mode is config!")
        check_file_or_directory_path(dump_path, True)

        self.device_id = torch_npu._C._npu_getDevice()
        self.dump_mode = dump_mode
        self.dump_api_list = api_list
        self.debug_flag = debug
        self.api_index = 0
        self.single_api_index_dict = {}
        self.device_dump_path_cpu = None
        self.device_dump_path_npu = None
        self.all_summery = []
        self.call_stack_list = []
        self.process_num = process_num
        # NOTE(review): filter_dump_api always leaves a list behind, so the
        # api_list type check in check_param cannot fail in this order.
        self.filter_dump_api()
        self.check_param()
        dir_name = self.get_dir_name(tag)
        self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
        self.root_cpu_path = os.path.join(self.root_path, 'cpu')
        self.root_npu_path = os.path.join(self.root_path, 'npu')
        check_path_before_create(self.root_cpu_path)
        check_path_before_create(self.root_npu_path)
        Path(self.root_cpu_path).mkdir(mode=0o750, parents=True, exist_ok=True)
        Path(self.root_npu_path).mkdir(mode=0o750, parents=True, exist_ok=True)

        self.result_csv_path = os.path.join(self.root_path, RESULT_FILE_NAME)
        self.detail_csv_path = os.path.join(self.root_path, DETAILS_FILE_NAME)
        self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)

        self.aten_ops_blacklist = []
        self.npu_adjust_autogard = []
        yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
        self.load_yaml_file(yaml_path)

        self.lock = None
        self._lock_manager = None
        if process_num > 0:
            # __torch_dispatch__ acquires self.lock in multi-process mode and
            # ships it to the workers, so it must be a real, picklable lock.
            # A manager-backed lock can be sent through apply_async arguments.
            self._lock_manager = Manager()
            self.lock = self._lock_manager.Lock()
            self.pool = Pool(process_num)
        if debug:
            logger_debug(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
                         f'process[{process_num}]')

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the dispatch mode and, in multi-process mode, collect worker summaries.

        With ``process_num > 0`` the pool is closed and joined, then every
        non-blank line of ``<root_cpu_path>/summary.json`` is parsed as JSON
        and stored into ``self.all_summery``. With ``debug`` enabled, counts of
        input/output entries are logged.
        """
        super().__exit__(exc_type, exc_val, exc_tb)

        if not is_npu:
            return
        logger_debug(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}')

        if self.process_num > 0:
            self.pool.close()
            self.pool.join()
            summery_path = os.path.join(self.root_cpu_path, 'summary.json')
            if not os.path.exists(summery_path):
                logger_error("Please check train log, An exception may have occurred!")
                return
            check_file_or_directory_path(summery_path, False)
            # The context manager closes the file even if a line fails to parse.
            with open(summery_path, "r") as fp_handle:
                for json_line_data in fp_handle:
                    if not json_line_data.strip():
                        continue
                    msg = json.loads(json_line_data)
                    # assumes msg[0] indexes a slot already appended to
                    # self.all_summery by __torch_dispatch__ — TODO confirm
                    self.all_summery[msg[0]] = msg[1]

        if self.debug_flag:
            input_num = 0
            output_num = 0
            total_num = 0

            for list_data in self.all_summery:
                for data in list_data:
                    logger_debug(f'summery: Device[{self.device_id}], Pid[{os.getpid()}], Data[{data}]')
                    if "_input" in data[CompareConst.NPU_NAME]:
                        input_num = input_num + 1
                    if "_output" in data[CompareConst.NPU_NAME]:
                        output_num = output_num + 1
                    total_num = total_num + 1
            logger_debug(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
                         f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` on NPU and on CPU copies of its inputs, then schedule the comparison.

        Returns:
            The result of ``func(*args, **kwargs)`` on the original arguments.
        """
        # The signature allows kwargs=None, but ``func(**None)`` raises TypeError.
        if kwargs is None:
            kwargs = {}

        if not is_npu:
            logger_error("Please confirm you run environment installed torch_npu!")
            return func(*args, **kwargs)

        # func.__name__ is expected to look like "<api>.<overload>".
        func_name_split_list = func.__name__.split(".")
        aten_api = func_name_split_list[0]
        try:
            aten_api_overload_name = func_name_split_list[1]
        except IndexError:
            logger_error(f"Please check the func name {func.__name__}!")
            return func(*args, **kwargs)

        self.enable_autogard(aten_api)
        if aten_api in self.aten_ops_blacklist:
            return func(*args, **kwargs)

        self.call_stack_list.append(get_callstack())
        self.api_index += 1
        self.single_api_index_dict[aten_api] = self.single_api_index_dict.get(aten_api, 0) + 1

        run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)

        if self.debug_flag:
            logger_debug(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
                         f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
                         f'Count[{self.api_index}], Sys[{get_sys_info()}]')

        # data_to_cpu appends its result to the list passed as third argument.
        cpu_args = []
        cpu_kwargs = []
        data_to_cpu(args, 0, cpu_args)
        data_to_cpu(kwargs, 0, cpu_kwargs)
        cpu_args = cpu_args[0]
        cpu_kwargs = cpu_kwargs[0]

        with TimeStatistics("NPU RUN", run_param):
            npu_out = func(*args, **kwargs)
        npu_out_cpu = []
        data_to_cpu(npu_out, 0, npu_out_cpu)
        npu_out_cpu = npu_out_cpu[0]

        with TimeStatistics("CPU RUN", run_param):
            cpu_out = func(*cpu_args, **cpu_kwargs)

        # torch.half is an alias of torch.float16, so two dtypes cover the original check.
        if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in (torch.bfloat16, torch.float16):
            cpu_out = cpu_out.float()

        if self.process_num == 0:
            self.all_summery.append([])
            data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, func, npu_out_cpu, cpu_out, self.lock)
            dispatch_workflow(run_param, data_info)
        else:
            with self.lock:
                self.all_summery.append([])
            run_param.process_flag = True
            if self.check_fun(func, run_param):
                # func itself is not sent to the worker (None is passed);
                # check_fun has recorded the namespace on run_param instead.
                data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summery, None, npu_out_cpu, cpu_out,
                                             self.lock)
                self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                      error_callback=error_call)
            else:
                logger_error("can not get correct function please set process_num=0")
        return npu_out

    @staticmethod
    def check_fun(func, run_param):
        """Return True when ``func`` is exactly ``torch.ops.aten.<api>.<overload>``.

        On success ``run_param.func_namespace`` is set to ``"aten"``.
        """
        if hasattr(torch.ops.aten, run_param.aten_api):
            aten_func = getattr(torch.ops.aten, run_param.aten_api)
            if hasattr(aten_func, run_param.aten_api_overload_name):
                aten_overload_func = getattr(aten_func, run_param.aten_api_overload_name)
                if aten_overload_func is func:
                    run_param.func_namespace = "aten"
                    return True
        return False

    def get_dir_name(self, tag):
        """Build the result directory name from ``tag``, the device id and the current time."""
        # guarantee file uniqueness
        time.sleep(1)
        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
        if tag is None or not isinstance(tag, str):
            logger_warn('There is not tag or the type of tag is not string.')
            dir_name = f'msprobe_rank{self.device_id}_{time_now}'
        else:
            dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
        return dir_name

    def load_yaml_file(self, file_path):
        """Load ``aten_ops_blacklist`` and ``npu_adjust_autogard`` from a yaml file.

        A missing key, a null value or an empty file yields an empty list, so
        the later ``in`` membership tests never run against ``None``.
        """
        with FileOpen(file_path, 'r') as f:
            yaml_file = yaml.safe_load(f) or {}
            self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist') or []
            self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard') or []

    def filter_dump_api(self):
        """Reduce ``self.dump_api_list`` to names present in ``dir(torch.ops.aten)``.

        The list is emptied unless ``dump_mode`` is ``Const.LIST`` and a
        non-empty list was supplied. Unknown names are dropped with a warning.
        """
        if self.dump_mode != Const.LIST or not self.dump_api_list:
            self.dump_api_list = []
            return
        aten_api_list = dir(torch.ops.aten)
        dump_api_list = []
        for aten_api in self.dump_api_list:
            if aten_api in aten_api_list:
                dump_api_list.append(aten_api)
            else:
                logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten')
        self.dump_api_list = dump_api_list

    def get_run_param(self, aten_api, func_name, aten_api_overload_name):
        """Create the ``DispatchRunParam`` describing the op currently being dispatched."""
        run_param = DispatchRunParam(self.debug_flag, self.device_id, self.root_npu_path, self.root_cpu_path,
                                     self.process_num, self.comparator)
        run_param.dump_flag, run_param.auto_dump_flag = self.get_dump_flag(aten_api)
        run_param.func_name = func_name
        run_param.aten_api = aten_api
        run_param.aten_api_overload_name = aten_api_overload_name
        run_param.single_api_index = self.single_api_index_dict[aten_api]
        run_param.api_index = self.api_index
        return run_param

    def get_dump_flag(self, aten_api):
        """Return ``(dump_flag, auto_dump_flag)`` for ``aten_api`` under the current dump mode.

        ``dump_flag`` is True in ``Const.ALL`` mode, or in ``Const.LIST`` mode
        when the api is in the dump list; ``auto_dump_flag`` is True only in
        ``Const.AUTO`` mode.
        """
        dump_flag = False
        auto_dump_flag = False
        if self.dump_mode == Const.ALL:
            dump_flag = True
        if self.dump_mode == Const.LIST and aten_api in self.dump_api_list:
            dump_flag = True
        if self.dump_mode == Const.AUTO:
            auto_dump_flag = True
        return dump_flag, auto_dump_flag

    def check_param(self):
        """Validate the constructor parameters stored on the instance.

        Raises:
            DispatchException: with ``INVALID_PARAMETER`` when the dump mode is
                not in ``Const.ONLINE_DUMP_MODE``, ``dump_api_list`` is not a
                list, ``debug_flag`` is not a bool, or ``process_num`` is not a
                non-negative int.
        """
        if self.dump_mode not in Const.ONLINE_DUMP_MODE:
            logger_error('The parameter "dump mode" can only be one of {}.'.format(Const.ONLINE_DUMP_MODE))
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.dump_api_list, list):
            logger_error('The type of parameter "api_list" can only be list.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.debug_flag, bool):
            logger_error('The type of parameter "debug" can only be bool.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)
        if not isinstance(self.process_num, int) or self.process_num < 0:
            logger_error('The type of parameter "process_num" can only be int and it should not be less than 0.')
            raise DispatchException(DispatchException.INVALID_PARAMETER)

    def enable_autogard(self, aten_api):
        """Stop excluding the AutogradFunctionality dispatch key for ops listed in ``npu_adjust_autogard``."""
        if aten_api in self.npu_adjust_autogard:
            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)