mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228) hide show
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+
18
+
19
class AdvisorConst:
    """
    Constants used by the advisor: text symbols, summary keys,
    advice messages and the keywords looked up in node names.
    """

    # text symbol
    NEW_LINE = "\n"
    COLON = ": "

    # advisor summary key
    SUSPECT_NODES = "Suspect Nodes"
    LINE = "Line"
    ADVISOR_SUGGEST = "Expert Advice"

    NO_ERROR_API = "NA"

    # advisor message
    NO_ERR_SUGGEST = "All data in comparison result meets the accuracy requirements."
    FORWARD_INPUT_SUGGEST = (
        "1. Analyze the model to view the input source.\n"
        "2. Check whether an inplace API causes the output result to overwrite the input result. "
        "That is, the fault is actually caused by a computation error.\n"
        "3. The fault may be caused by memory corruption and further analysis is required."
    )
    FORWARD_OUTPUT_SUGGEST = "This is a forward API computation error. Check the computation implementation."
    BACKWARD_INPUT_SUGGEST = "Check whether the forward computation result is affected."
    BACKWARD_OUTPUT_SUGGEST = "This is a backward API computation error. Check the computation implementation."
    BATCH_NORM_SUGGEST = (
        "Torch API batch_norm input not fixed, the following suggestions may fix it:\n"
        "1. If use torch.nn.functional.batch_norm, you can set parameter training=False.\n"
        "2. If use torch.nn.BatchNormXXX, you can set parameter affine=False.\n"
        "3. Use seed_all(mode=True) to enable deterministic computing."
    )
    DETERMINISTIC_SUGGEST = (
        "This torch api may be uncertainty in the calculation, "
        "can seed_all(mode=True) to enable deterministic computing."
    )

    FUNC_BATCH_NORM = "Functional_batch_norm"
    FORWARD_INPUT_1 = "forward_input.1"
    NEED_DETERMINISTIC_API = ["conv2d", "conv3d", "matmul", "nll_loss", "layer_norm", "lstm"]
    BATCH_NORM = "batch_norm"

    # name keyword
    INPUT = "input"
    OUTPUT = "output"
    FORWARD = "forward"
    BACKWARD = "backward"
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import os
18
+ import time
19
+
20
+ from msprobe.pytorch.advisor.advisor_const import AdvisorConst
21
+ from msprobe.pytorch.common.log import logger
22
+ from msprobe.core.common.const import Const, FileCheckConst
23
+ from msprobe.core.common.file_check import change_mode
24
+
25
+
26
class AdvisorResult:
    """
    Holds one piece of expert advice (suspect node, line and message)
    and renders it to the log or to a summary file.
    """

    def __init__(self, node, line, message):
        self.suspect_node = node
        self.line = line
        self.advisor_message = message

    @staticmethod
    def gen_summary_file(out_path, message_list):
        """
        Function Description:
            Write the advice messages, one per line, to a timestamped
            'advisor_<YYYYmmddHHMMSS>.txt' file under out_path.
        Parameter:
            out_path: directory in which the summary file is created
            message_list: iterable of messages to write
        Exception Description:
            IOError is caught and logged; it is not propagated to the caller
        """
        file_name = 'advisor_{}.txt'.format(time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
        result_file = os.path.join(out_path, file_name)
        try:
            with os.fdopen(os.open(result_file, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as output_file:
                # Const.WRITE_FLAGS is defined outside this file, so truncate explicitly
                # in case an existing file with the same name is reopened.
                output_file.truncate(0)
                # str() keeps a non-str message from aborting the write with a TypeError.
                output_file.writelines(str(message) + AdvisorConst.NEW_LINE for message in message_list)
            change_mode(result_file, FileCheckConst.DATA_FILE_AUTHORITY)
        except IOError as io_error:
            logger.error("Failed to save %s, the reason is %s." % (result_file, io_error))
        else:
            logger.info("The advisor summary is saved in: %s" % result_file)

    def print_advisor_log(self):
        """
        Function Description:
            Log the advice as three "key: value" lines (line, suspect node, advice).
        Return:
            message_list: the three formatted lines, in the order they were logged
        """
        logger.info("The summary of the expert advice is as follows: ")
        # Every field goes through str(); previously only `line` was converted, so a
        # non-str node or message raised TypeError during concatenation.
        message_list = [AdvisorConst.LINE + AdvisorConst.COLON + str(self.line),
                        AdvisorConst.SUSPECT_NODES + AdvisorConst.COLON + str(self.suspect_node),
                        AdvisorConst.ADVISOR_SUGGEST + AdvisorConst.COLON + str(self.advisor_message)]
        for message in message_list:
            logger.info(message)
        return message_list
File without changes
File without changes
File without changes
@@ -0,0 +1,50 @@
1
+ import os
2
+ import yaml
3
+ from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path
4
+ from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps
5
+ from msprobe.core.common.file_check import FileOpen
6
+
7
+ WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps)
8
+
9
+
10
class Config:
    """
    Validated view of the checker's YAML configuration.
    Each validated item is readable as an attribute, e.g. ``cfg.precision``.
    """

    def __init__(self, yaml_file):
        """
        Load yaml_file and validate every item in it.

        Raises ValueError when the file content is not a mapping or an item is invalid.
        """
        check_file_or_directory_path(yaml_file, False)
        with FileOpen(yaml_file, 'r') as file:
            config = yaml.safe_load(file)
        # An empty file yields None and a top-level list/scalar has no .items();
        # report both clearly instead of failing with an AttributeError.
        if not isinstance(config, dict):
            raise ValueError(f"{yaml_file} must contain a mapping of configuration items")
        self.config = {key: self.validate(key, value) for key, value in config.items()}

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails. Read the mapping through
        # __dict__ so a not-yet-initialised instance cannot recurse into __getattr__,
        # and raise AttributeError (not KeyError) so hasattr()/getattr(default) work.
        config = self.__dict__.get('config')
        if config is None or item not in config:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return config[item]

    def __str__(self):
        return '\n'.join(f"{key}={value}" for key, value in self.config.items())

    @staticmethod
    def validate(key, value):
        """
        Check one configuration item and return its value unchanged.

        Raises ValueError for an unknown key, a value of the wrong type, a negative
        precision, or a white_list entry that is not a str or not in WrapApi.
        """
        validators = {
            'white_list': list,
            'error_data_path': str,
            'precision': int
        }
        if key not in validators:
            raise ValueError(f"{key} must be one of {validators.keys()}")
        if not isinstance(value, validators.get(key)):
            raise ValueError(f"{key} must be {validators[key].__name__} type")
        # Note: only negative values are rejected, so a precision of 0 is accepted.
        if key == 'precision' and value < 0:
            raise ValueError("precision must be greater than 0")
        if key == 'white_list':
            # The generic type check above already guarantees value is a list.
            if not all(isinstance(i, str) for i in value):
                raise ValueError("All elements in white_list must be of str type")
            invalid_api = [i for i in value if i not in WrapApi]
            if invalid_api:
                raise ValueError(
                    f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list")
        return value
46
+
47
+
48
+ cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
49
+ yaml_path = os.path.join(cur_path, "config.yaml")
50
+ msCheckerConfig = Config(yaml_path)
@@ -0,0 +1,224 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import json
18
+ import os
19
+ import re
20
+ import csv
21
+
22
+ import torch
23
+
24
+ try:
25
+ import torch_npu
26
+ except ImportError:
27
+ IS_GPU = True
28
+ else:
29
+ IS_GPU = False
30
+
31
+ from msprobe.pytorch.common.log import logger
32
+ from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
33
+ from msprobe.core.common.const import Const, FileCheckConst
34
+ from msprobe.core.common.utils import CompareException
35
+
36
+
37
class DumpException(CompareException):
    """Exception type for dump-data handling; adds nothing to CompareException."""
39
+
40
+
41
def write_csv(data, filepath):
    """
    Function Description:
        Append rows to a csv file opened in 'a' mode with utf-8-sig encoding
    Parameter:
        data: iterable of rows passed to csv.writer.writerows
        filepath: path of the csv file
    """
    with FileOpen(filepath, 'a', encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
45
+
46
+
47
def check_object_type(check_object, allow_type):
    """
    Function Description:
        Verify that an object is an instance of the allowed type(s)
    Parameter:
        check_object: the object to be checked
        allow_type: a type or tuple of types accepted by isinstance
    Exception Description:
        logs an error and raises CompareException(INVALID_DATA_ERROR) on a type mismatch
    """
    if isinstance(check_object, allow_type):
        return
    logger.error(f"{check_object} not of {allow_type} type")
    raise CompareException(CompareException.INVALID_DATA_ERROR)
60
+
61
+
62
def check_file_or_directory_path(path, isdir=False):
    """
    Function Description:
        Validate a path: a directory must exist, be a directory and be writable;
        a file must be a regular file and be readable
    Parameter:
        path: the path to check
        isdir: True to check the path as a directory, False to check it as a file
    Exception Description:
        logs the reason and raises CompareException(INVALID_PATH_ERROR) on the first failed check
    """
    def _reject(message):
        logger.error(message)
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not isdir:
        if not os.path.isfile(path):
            _reject('{} is an invalid file or non-exist.'.format(path))
        if not os.access(path, os.R_OK):
            _reject('The path {} does not have permission to read. Please check the path permission'.format(path))
        return

    if not os.path.exists(path):
        _reject('The path {} is not exist.'.format(path))
    if not os.path.isdir(path):
        _reject('The path {} is not a directory.'.format(path))
    if not os.access(path, os.W_OK):
        _reject('The path {} does not have permission to write. Please check the path permission'.format(path))
94
+
95
+
96
def get_json_contents(file_path):
    """
    Function Description:
        Read a json file and return its content, which must be a dictionary
    Parameter:
        file_path: path of the json file
    Exception Description:
        raises CompareException(INVALID_FILE_ERROR) when the content cannot be
        parsed as json or the parsed content is not a dict
    """
    raw_content = get_file_content_bytes(file_path)
    try:
        parsed = json.loads(raw_content)
    except ValueError as error:
        logger.error('Failed to load "%s". %s' % (file_path, str(error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(parsed, dict):
        return parsed
    logger.error('Json file %s, content is not a dictionary!' % file_path)
    raise CompareException(CompareException.INVALID_FILE_ERROR)
107
+
108
+
109
def get_file_content_bytes(file):
    """Return the whole content of ``file``, read in binary mode through FileOpen."""
    with FileOpen(file, 'rb') as file_handle:
        content = file_handle.read()
    return content
112
+
113
+
114
class SoftlinkCheckException(Exception):
    """Dedicated exception type; presumably for soft-link check failures (not raised in the visible code)."""
116
+
117
+
118
def check_need_convert(api_name):
    """
    Function Description:
        Look up the conversion type whose api collection in Const.CONVERT_API contains api_name
    Parameter:
        api_name: Name of the API.
    Return:
        the matching key of Const.CONVERT_API, or None when no entry contains api_name
    """
    matched_type = None
    for candidate_type, api_names in Const.CONVERT_API.items():
        if api_name in api_names:
            # No break on purpose: when several entries match, the last one wins.
            matched_type = candidate_type
    return matched_type
126
+
127
+
128
def api_info_preprocess(api_name, api_info_dict):
    """
    Function Description:
        Determine the conversion type of an API and adjust its argument information
        (only 'cross_entropy' is adjusted).
    Parameter:
        api_name: Name of the API.
        api_info_dict: argument of the API.
    Return:
        convert_type: Type of conversion, None when the API needs none.
        api_info_dict: Processed argument of the API.
    """
    if api_name == 'cross_entropy':
        api_info_dict = cross_entropy_process(api_info_dict)
    return check_need_convert(api_name), api_info_dict
143
+
144
+
145
def cross_entropy_process(api_info_dict):
    """
    Function Description:
        Preprocesses the cross_entropy API information: when the recorded 'Min' of the
        second positional argument is not positive, it is raised to 0 (in place).
    Parameter:
        api_info_dict: argument of the API.
    Return api_info_dict:
        api_info_dict: Processed argument of the API (the same object that was passed in).
    """
    if 'args' not in api_info_dict or len(api_info_dict['args']) <= 1:
        return api_info_dict
    second_arg = api_info_dict['args'][1]
    # The second argument may be recorded as something other than a dict (e.g. None);
    # leave such entries untouched instead of failing on the 'Min' lookup.
    if not isinstance(second_arg, dict):
        return api_info_dict
    min_value = second_arg.get('Min')
    # A missing statistic (None) cannot be compared with 0, so it is skipped as well.
    if min_value is not None and min_value <= 0:
        # The second argument in cross_entropy should be -100 or not less than 0
        second_arg['Min'] = 0
    return api_info_dict
159
+
160
+
161
def initialize_save_path(save_path, dir_name):
    """
    Function Description:
        Make sure the directory save_path/dir_name exists, then run the common
        directory check on it
    Parameter:
        save_path: parent directory
        dir_name: name of the directory to create under save_path
    """
    data_path = os.path.join(save_path, dir_name)
    if not os.path.exists(data_path):
        os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
    else:
        logger.warning(f"{data_path} already exists, it will be overwritten")
    FileChecker(data_path, FileCheckConst.DIR).common_check()
169
+
170
+
171
def write_pt(file_path, tensor):
    """
    Function Description:
        Save a tensor with torch.save to a file that must not exist yet, then set
        the file permission to FileCheckConst.DATA_FILE_AUTHORITY
    Parameter:
        file_path: destination path
        tensor: object handed to torch.save
    Return:
        the real (resolved) path of the saved file
    Exception Description:
        raises ValueError when file_path already exists
    """
    if os.path.exists(file_path):
        raise ValueError(f"File {file_path} already exists")
    torch.save(tensor, file_path)
    saved_path = os.path.realpath(file_path)
    change_mode(saved_path, FileCheckConst.DATA_FILE_AUTHORITY)
    return saved_path
178
+
179
+
180
def get_real_data_path(file_path):
    """
    Function Description:
        Cut a path down to the part that starts at the first real-data directory
        marker: 'forward_real_data', 'backward_real_data' or 'ut_error_data<digits>'
    Parameter:
        file_path: path containing one of the markers
    Return:
        the substring of file_path from the first marker to the end
    Exception Description:
        raises DumpException(INVALID_PATH_ERROR) when no marker is found
    """
    # Raw string: '\d' in a normal literal is an invalid escape sequence and
    # triggers a warning on recent Python versions. The regex itself is unchanged.
    targets = ['forward_real_data', 'backward_real_data', r'ut_error_data\d+']
    pattern = re.compile(r'({})'.format('|'.join(targets)))
    match = pattern.search(file_path)
    if match is None:
        raise DumpException(DumpException.INVALID_PATH_ERROR)
    return file_path[match.start():]
190
+
191
+
192
def get_full_data_path(data_path, real_data_path):
    """
    Function Description:
        Join data_path onto real_data_path and resolve the result
    Parameter:
        data_path: path to resolve; a falsy value (None, '') is returned unchanged
        real_data_path: base directory the data path is joined to
    Return:
        the resolved full path, or data_path itself when it is falsy
    """
    if data_path:
        return os.path.realpath(os.path.join(real_data_path, data_path))
    return data_path
197
+
198
+
199
class UtDataProcessor:
    """
    Saves every tensor found in a (possibly nested) API argument structure as
    '<api_name><Const.SEP><index>.pt' under save_path.
    """

    def __init__(self, save_path):
        self.save_path = save_path
        self.index = 0

    def save_tensors_in_element(self, api_name, element):
        """Reset the running index and save all tensors contained in element."""
        self.index = 0
        self._save_recursive(api_name, element)

    def _save_recursive(self, api_name, element):
        """
        Walk element depth-first. Lists/tuples and dict values are descended into;
        every other item (tensor or not) consumes one index, and tensors are also saved.
        """
        if isinstance(element, (list, tuple)):
            for item in element:
                self._save_recursive(api_name, item)
            return
        if isinstance(element, dict):
            for value in element.values():
                self._save_recursive(api_name, value)
            return
        if isinstance(element, torch.Tensor):
            tensor_name = api_name + Const.SEP + str(self.index)
            create_directory(self.save_path)
            target_file = os.path.join(self.save_path, f'{tensor_name}.pt')
            write_pt(target_file, element.contiguous().cpu().detach())
        # Leaves of any kind advance the index so positions stay aligned.
        self.index += 1
@@ -0,0 +1,216 @@
1
+ # 定义比对算法及比对标准
2
+ import torch
3
+ import numpy as np
4
+
5
+ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import ULP_PARAMETERS
6
+ from msprobe.core.common.const import CompareConst
7
+
8
+
9
+ DEFAULT_THRESHOLD = 1
10
+
11
+
12
+ #cos
13
def cosine_sim(bench_output, device_output):
    """Cosine similarity between the flattened bench and device outputs.

    Returns a ``(cos, passed, msg)`` tuple:

    * shapes differ after flattening       -> ``(-1, False, msg)``
    * single element                       -> ``(CompareConst.SPACE, True, msg)``
    * both inputs are all (near) zero      -> ``(CompareConst.SPACE, True, msg)``
    * exactly one input is all (near) zero -> ``(CompareConst.SPACE, False, msg)``
    * otherwise                            -> ``(cos, cos > 0.99, msg)`` with
      ``cos`` clipped to ``[-1, 1]``; ``msg`` is non-empty only if ``cos`` is NaN.

    "Near zero" means the maximum absolute value is ``<= np.finfo(float).eps``.

    NOTE: ``np.seterr`` is called without being restored, so NumPy's
    divide/invalid error handling stays set to "ignore" after this call.
    """
    device_flat = device_output.reshape(-1)
    bench_flat = bench_output.reshape(-1)
    placeholder = CompareConst.SPACE
    np.seterr(divide="ignore", invalid="ignore")

    if device_flat.shape != bench_flat.shape:
        mismatch = (
            f"Shape of device and bench outputs don't match. "
            f"device: {device_flat.shape}, bench: {bench_flat.shape}."
        )
        return -1, False, mismatch

    if len(device_flat) == 1:
        return placeholder, True, \
            "All the data in device dump data is scalar. Please refer to other compare algorithms."

    tiny = np.finfo(float).eps
    device_peak = np.max(np.abs(device_flat))
    bench_peak = np.max(np.abs(bench_flat))
    device_all_zero = device_peak <= tiny
    bench_all_zero = bench_peak <= tiny

    if device_all_zero and bench_all_zero:
        return placeholder, True, "All the data in device and bench outputs are zero."
    if device_all_zero:
        return placeholder, False, "All the data is zero in device dump data."
    if bench_all_zero:
        return placeholder, False, "All the data is zero in bench dump data."

    # Normalise each vector by its own peak before taking the dot product.
    device_unit = device_flat.astype(float) / device_peak
    bench_unit = bench_flat.astype(float) / bench_peak
    norm_product = np.linalg.norm(device_unit) * np.linalg.norm(bench_unit)
    cos = np.dot(device_unit, bench_unit) / norm_product

    msg = "Dump data has NaN when comparing with Cosine Similarity." if np.isnan(cos) else ""
    cos = np.clip(cos, -1, 1)
    return cos, cos > 0.99, msg
44
+
45
+
46
+ #rmse
47
def get_rmse(abs_err, inf_nan_mask):
    """Root-mean-square error over the entries not flagged by ``inf_nan_mask``.

    Flagged entries are zeroed before squaring. The mean over all entries is
    then rescaled by ``size / (size - flagged + 0.0001) + 0.0001``; the
    ``0.0001`` terms keep the denominator non-zero when every entry is flagged.
    """
    total = abs_err.size
    cleaned_err = np.where(inf_nan_mask, 0, abs_err)
    flagged = np.sum(inf_nan_mask)
    rescale = total / (total - flagged + 0.0001) + 0.0001
    return np.sqrt(np.mean(np.square(cleaned_err)) * rescale)
54
+
55
+
56
+ #误差均衡性
57
def get_error_balance(bench_data, device_data):
    """Measure how one-sided the sign of ``device_data - bench_data`` is.

    ``bench_data`` is cast to ``device_data``'s dtype before subtracting.
    The result is ``|#positive - #negative| / bench_data.size``, or ``0``
    when ``bench_data`` is empty.
    """
    # Compute the difference (including the dtype cast) once instead of
    # once per comparison.
    diff = device_data - bench_data.astype(device_data.dtype)
    larger_count = np.sum(np.greater(diff, 0))
    smaller_count = np.sum(np.less(diff, 0))
    total_count = bench_data.size
    if total_count > 0:
        return abs(larger_count - smaller_count) / total_count
    return 0
63
+
64
+
65
+ #小值域错误占比
66
def get_small_value_err_ratio(small_value_mask, abs_err_greater_mask):
    """Fraction of small-value entries that are also flagged in ``abs_err_greater_mask``.

    Returns ``0`` when ``small_value_mask`` selects nothing.
    """
    flagged_small = np.sum(np.logical_and(small_value_mask, abs_err_greater_mask))
    all_small = np.sum(small_value_mask)
    if all_small == 0:
        return 0
    return flagged_small / all_small
71
+
72
+
73
def get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask):
    """Element-wise ``abs_err / abs_bench_with_eps`` with excluded entries set to ``-1``.

    An entry is excluded when it is flagged in ``small_value_mask`` or in
    ``inf_nan_mask``.
    """
    excluded = np.logical_or(small_value_mask, inf_nan_mask)
    raw_ratio = abs_err / abs_bench_with_eps
    return np.where(excluded, -1, raw_ratio)
78
+
79
+
80
def get_abs_err(bench_data, device_data):
    """Return the element-wise absolute difference ``|device_data - bench_data|``."""
    delta = device_data - bench_data
    return np.abs(delta)
83
+
84
+
85
def get_rel_err_origin(abs_err, b_value):
    """Return ``|abs_err / b_value|`` element-wise; no epsilon or masking is applied."""
    quotient = abs_err / b_value
    return np.abs(quotient)
88
+
89
+
90
def get_max_abs_err(abs_err):
    """Return ``(max, passed)`` where ``passed`` is ``max < 0.001``."""
    largest = abs_err.max()
    return largest, largest < 0.001
94
+
95
+
96
+ #相对误差最大值
97
def get_max_rel_err(rel_err):
    """Return the maximum of ``rel_err``, or ``0`` if that maximum is not ``>= 0``.

    A negative maximum (or a NaN one, which fails the comparison) yields ``0``.
    """
    # Reduce once; the original evaluated np.max(rel_err) twice.
    max_rel_err = np.max(rel_err)
    return max_rel_err if max_rel_err >= 0 else 0
99
+
100
+
101
+ #相对误差均值
102
def get_mean_rel_err(rel_err):
    """Mean of the non-negative entries of ``rel_err``; ``0`` if there are none."""
    kept = rel_err[rel_err >= 0]
    if kept.size > 0:
        return np.mean(kept)
    return 0
105
+
106
+
107
def get_rel_err_ratio(rel_err, thresholding):
    """Share of entries with ``rel_err < thresholding`` and whether that share passes.

    Returns ``(ratio, passed)`` where ``passed`` is ``ratio > 1 - thresholding``.
    An empty ``rel_err`` gives a ratio of ``1``.
    """
    count = np.size(rel_err)
    ratio = 1 if count == 0 else np.divide(np.sum(rel_err < thresholding), count)
    return ratio, ratio > (1 - thresholding)
114
+
115
+
116
def get_finite_and_infinite_mask(bench_output, device_output):
    """Return ``(both_finite_mask, inf_nan_mask)`` for the two outputs.

    ``bench_output`` is cast to ``device_output``'s dtype before the finiteness
    check. ``inf_nan_mask`` is the logical negation of ``both_finite_mask``.
    """
    bench_cast = bench_output.astype(device_output.dtype)
    both_finite = np.logical_and(np.isfinite(device_output), np.isfinite(bench_cast))
    return both_finite, np.logical_not(both_finite)
122
+
123
+
124
def get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold):
    """Mask of entries that are finite in both outputs and ``<= small_value_threshold``."""
    at_or_below = np.less_equal(abs_bench, small_value_threshold)
    return np.logical_and(at_or_below, both_finite_mask)
128
+
129
+
130
def get_abs_bench_with_eps(bench, dtype):
    """Return ``(|bench|, |bench| + eps)``.

    ``eps`` is ``np.finfo(bench.dtype).eps`` unless ``dtype`` is
    ``torch.bfloat16``, in which case ``CompareConst.BFLOAT16_EPS`` is used.
    """
    magnitude = np.abs(bench)
    if dtype != torch.bfloat16:
        eps = np.finfo(bench.dtype).eps
    else:
        eps = CompareConst.BFLOAT16_EPS
    return magnitude, magnitude + eps
135
+
136
+
137
def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol):
    '''
    In the absolute-threshold method of the new accuracy standard, check
    whether the inf/nan values of the NPU output agree with the golden output.

    Inputs:
        inf_nan_mask: mask of the inf/nan positions in the NPU and golden outputs
        bench_output: golden output
        device_output: NPU output
        dtype: dtype of the NPU output
        rtol: relative tolerance applied to the clipped values
    Output:
        inf_nan_err_ratio: share of masked positions where the NPU and golden
        outputs disagree; 0 when the mask selects nothing
    '''
    _, bench_abs_with_eps = get_abs_bench_with_eps(bench_output, dtype)
    bench_cast = bench_output.astype(device_output.dtype)

    # Representable range used to clip +/-inf to finite values.
    if dtype != torch.bfloat16:
        limits = np.finfo(device_output.dtype)
        lower, upper = limits.min, limits.max
    else:
        lower, upper = CompareConst.BFLOAT16_MIN, CompareConst.BFLOAT16_MAX

    bench_clipped = np.clip(bench_cast, lower, upper)
    device_clipped = np.clip(device_output, lower, upper)
    clipped_rel_err = np.abs(device_clipped - bench_clipped) / bench_abs_with_eps

    # A position passes if the clipped relative error is within rtol, or if
    # both sides are NaN.
    within_tol = np.less_equal(clipped_rel_err, rtol)
    nan_on_both = np.logical_and(np.isnan(device_output), np.isnan(bench_clipped))
    passed = np.logical_or(within_tol, nan_on_both)

    # Only failures inside the inf/nan mask are counted.
    failed_in_mask = np.logical_and(np.logical_not(passed), inf_nan_mask)
    masked_total = np.sum(inf_nan_mask)
    return 0 if masked_total == 0 else np.sum(failed_in_mask) / masked_total
164
+
165
+
166
def check_small_value(abs_err, small_value_mask, small_value_atol):
    '''
    Check whether the absolute error of the small-value entries of the NPU
    and golden outputs stays within the absolute tolerance.

    Inputs:
        abs_err: absolute error between the NPU and golden outputs
        small_value_mask: mask selecting the small-value entries
        small_value_atol: absolute-error threshold for small-value entries
    Output:
        share of small-value entries whose absolute error exceeds
        small_value_atol; 0 when the mask selects nothing
    '''
    greater_mask = np.greater(abs_err, small_value_atol)
    err_cnt = np.sum(np.logical_and(greater_mask, small_value_mask))
    # Count the masked entries once; the original summed the mask twice.
    small_value_cnt = np.sum(small_value_mask)
    return 0 if small_value_cnt == 0 else err_cnt / small_value_cnt
180
+
181
+
182
def check_norm_value(normal_value_mask, rel_err, rtol):
    '''
    Check whether the relative error of the normal-value entries of the NPU
    and golden outputs stays within the relative tolerance.

    Inputs:
        normal_value_mask: mask selecting the normal-value entries
        rel_err: relative error between the NPU and golden outputs
        rtol: relative-error threshold
    Output:
        share of normal-value entries whose relative error exceeds rtol;
        0 when the mask selects nothing
    '''
    err_mask = np.logical_and(np.greater(rel_err, rtol), normal_value_mask)
    err_cnt = np.sum(err_mask)
    # Count the masked entries once; the original summed the mask twice.
    normal_value_cnt = np.sum(normal_value_mask)
    return 0 if normal_value_cnt == 0 else err_cnt / normal_value_cnt
196
+
197
+
198
def get_ulp_err(bench_output, device_output, dtype):
    """Return ``|device_output - bench_output|`` scaled by ``2 ** (exponent_num - eb)``.

    ``eb`` is ``floor(log2(|bench_output|))`` per element (``0`` where the
    bench value is zero), clamped from below by ``min_eb``. ``min_eb`` and
    ``exponent_num`` are looked up in ``ULP_PARAMETERS[dtype]``. The scaling
    is done in float64 when ``dtype`` is ``torch.float32`` and in float32
    otherwise.

    NOTE(review): a ``dtype`` missing from ``ULP_PARAMETERS`` still fails with
    ``AttributeError`` on the ``None`` lookup result, as before.
    """
    parameters = ULP_PARAMETERS.get(dtype)
    # Entries are read with [0], so the fallback must itself be indexable.
    # The original passed the bare int DEFAULT_THRESHOLD, which raised
    # TypeError whenever a key was missing.
    min_eb = parameters.get('min_eb', [DEFAULT_THRESHOLD])[0]
    exponent_num = parameters.get('exponent_num', [DEFAULT_THRESHOLD])[0]

    abs_bench = np.abs(bench_output)
    eb = np.where(abs_bench == 0, 0, np.floor(np.log2(abs_bench)))
    eb = np.maximum(eb, min_eb)

    calc_type = np.float64 if dtype == torch.float32 else np.float32
    ulp_err = calc_ulp_err(bench_output, device_output, eb, exponent_num, calc_type)
    return np.abs(ulp_err)
212
+
213
+
214
def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type):
    """Return ``(device_output - bench_output) * 2 ** (exponent_num - eb)`` in ``data_type``.

    ``device_output`` is cast to ``data_type`` before the subtraction; the
    difference and the power-of-two scale are each cast to ``data_type``
    before being multiplied.
    """
    diff = (device_output.astype(data_type) - bench_output).astype(data_type)
    scale = np.exp2(-eb + exponent_num).astype(data_type)
    return diff * scale