mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,616 +1,385 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import collections
18
- import os
19
- import re
20
- import shutil
21
- import subprocess
22
- import time
23
- import json
24
- import csv
25
- from datetime import datetime, timezone
26
- from pathlib import Path
27
- import yaml
28
- import numpy as np
29
-
30
- from msprobe.core.common.file_check import FileOpen, FileChecker, change_mode
31
- from msprobe.core.common.const import Const, FileCheckConst, CompareConst
32
- from msprobe.core.common.log import logger
33
-
34
-
35
# Lightweight (type, index) record describing a device.
device = collections.namedtuple('device', 'type index')
# Recognised mode-name prefixes (see callers for how they are matched).
prefixes = ['api_stack', 'list', 'range', 'acl']
37
-
38
-
39
class CompareException(Exception):
    """
    Class for Accuracy Compare Exception.

    Carries a numeric error ``code`` (normally one of the class-level constants
    below) and an optional human-readable ``error_info``, which is also the
    string form of the exception.
    """
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20

    def __init__(self, code, error_info: str = ""):
        # Forward both arguments to Exception so that ``self.args`` is populated.
        # pickle rebuilds exceptions as ``cls(*self.args)``; with empty args that
        # call fails for lack of ``code`` (e.g. when the exception is sent back
        # from a multiprocessing worker).
        super().__init__(code, error_info)
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
72
-
73
-
74
class DumpException(CompareException):
    """CompareException subclass for dump-related errors; it adds no behaviour of its own."""
76
-
77
-
78
def make_dump_path_if_not_exists(dump_path):
    """
    Ensure ``dump_path`` exists as a directory, creating it (and parents, mode 0o750) if missing.

    Raises:
        CompareException(INVALID_PATH_ERROR): the directory cannot be created, or
            ``dump_path`` already exists but is not a directory.
    """
    if os.path.exists(dump_path):
        if not os.path.isdir(dump_path):
            logger.error('{} already exists and is not a directory.'.format(dump_path))
            # Fail fast: returning normally would let callers try to write dump data
            # into a path that is not a directory.
            raise CompareException(CompareException.INVALID_PATH_ERROR)
        return
    try:
        Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
89
-
90
-
91
def check_mode_valid(mode, scope=None, api_list=None):
    """
    Validate a dump ``mode`` together with its ``scope`` / ``api_list`` arguments.

    Args:
        mode: dump mode; must be contained in ``Const.DUMP_MODE``.
        scope: list whose required length depends on ``mode`` (defaults to []).
        api_list: list of API names; must be non-empty for ``Const.API_LIST`` (defaults to []).

    Raises:
        ValueError: ``scope``/``api_list`` is not a list, or does not fit ``mode``.
        CompareException(INVALID_DUMP_MODE): ``mode`` is not supported.
    """
    if scope is None:
        scope = []
    if api_list is None:
        api_list = []
    if not isinstance(scope, list):
        raise ValueError("scope param set invalid, it's must be a list.")
    if not isinstance(api_list, list):
        raise ValueError("api_list param set invalid, it's must be a list.")
    if mode not in Const.DUMP_MODE:
        msg = "Current mode '%s' is not supported. Please use the field in %s" % \
              (mode, Const.DUMP_MODE)
        raise CompareException(CompareException.INVALID_DUMP_MODE, msg)

    # Each checker returns the ValueError to raise, or None when the combination is valid.
    mode_check = {
        Const.ALL: lambda: None,
        Const.RANGE: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
        Const.LIST: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
        Const.STACK: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
        Const.ACL: lambda: ValueError(
            "set_dump_switch, scope param set invalid, only one api name is supported in acl mode.")
        if len(scope) != 1 else None,
        Const.API_LIST: lambda: ValueError(
            "Current dump mode is 'api_list', but the content of api_list parameter is empty or invalid.")
        if len(api_list) < 1 else None,
        Const.API_STACK: lambda: None,
    }
    # Evaluate the checker once. A supported mode without a dedicated checker is treated as
    # needing no extra validation, instead of crashing on a missing dict entry.
    error = mode_check.get(mode, lambda: None)()
    if error is not None:
        raise error
116
-
117
-
118
def check_switch_valid(switch):
    """Accept only the literal switch values "ON" and "OFF"; log and raise for anything else."""
    if switch in ("ON", "OFF"):
        return
    logger.error("Please set switch with 'ON' or 'OFF'.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
122
-
123
-
124
def check_dump_mode_valid(dump_mode):
    """
    Normalise ``dump_mode`` into a list of dump directions and data kinds.

    A non-list value is wrapped in a list (with a warning). If neither 'input' nor
    'output' is given, both are added; likewise for 'forward'/'backward'. 'all', or a
    selection covering all four, yields ["forward", "backward", "input", "output"].
    The caller's list is never modified.

    Raises:
        ValueError: an entry is not one of 'all', 'forward', 'backward', 'input', 'output'.
    """
    if isinstance(dump_mode, list):
        # Work on a copy: the defaults below must not leak back into the caller's list.
        dump_mode = list(dump_mode)
    else:
        logger.warning("Please set dump_mode as a list.")
        dump_mode = [dump_mode]
    valid_modes = ["all", "forward", "backward", "input", "output"]
    if not all(mode in valid_modes for mode in dump_mode):
        raise ValueError("Please set dump_mode as a list containing one or more of the following: "
                         "'all', 'forward', 'backward', 'input', 'output'.")
    if 'input' not in dump_mode and 'output' not in dump_mode:
        dump_mode.extend(['input', 'output'])
    if 'forward' not in dump_mode and 'backward' not in dump_mode:
        dump_mode.extend(['forward', 'backward'])
    full_mode = ["forward", "backward", "input", "output"]
    if 'all' in dump_mode or set(full_mode).issubset(set(dump_mode)):
        return full_mode
    return dump_mode
137
-
138
-
139
def check_summary_mode_valid(summary_mode):
    """Raise CompareException(INVALID_SUMMARY_MODE) unless ``summary_mode`` is in Const.SUMMARY_MODE."""
    if summary_mode in Const.SUMMARY_MODE:
        return
    raise CompareException(CompareException.INVALID_SUMMARY_MODE, "The summary_mode is not valid")
143
-
144
-
145
def check_summary_only_valid(summary_only):
    """Return ``summary_only`` unchanged when it is a bool; otherwise log and raise INVALID_PARAM_ERROR."""
    if isinstance(summary_only, bool):
        return summary_only
    logger.error("Params summary_only only support True or False.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
150
-
151
-
152
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
    """
    Validate the paths used by a compare run and make sure the three json files are non-empty.

    Args:
        input_param: dict holding the npu/bench/stack json paths and, for a full
            (non-summary, non-md5) compare, the npu/bench dump data directories.
        output_path: output directory; must be a writable directory.
        summary_compare: when truthy, the dump data directories are not checked.
        md5_compare: when truthy, the dump data directories are not checked.

    Raises:
        CompareException(INVALID_PARAM_ERROR): the argument types are wrong; path
            checks and the json check may raise their own errors.
    """
    if not isinstance(input_param, dict) or not isinstance(output_path, str):
        logger.error("Invalid input parameters")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    json_keys = ("npu_json_path", "bench_json_path", "stack_json_path")
    for key in json_keys:
        check_file_or_directory_path(input_param.get(key), False)
    if not (summary_compare or md5_compare):
        for key in ("npu_dump_data_dir", "bench_dump_data_dir"):
            check_file_or_directory_path(input_param.get(key), True)
    check_file_or_directory_path(output_path, True)

    npu_path, bench_path, stack_path = (input_param.get(key) for key in json_keys)
    with FileOpen(npu_path, "r") as npu_json, \
            FileOpen(bench_path, "r") as bench_json, \
            FileOpen(stack_path, "r") as stack_json:
        check_json_file(input_param, npu_json, bench_json, stack_json)
169
-
170
-
171
-
172
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """Ensure the three compare switches are real bools (merely truthy values are rejected)."""
    flags = (stack_mode, auto_analyze, fuzzy_match)
    if all(isinstance(flag, bool) for flag in flags):
        return
    logger.error("Invalid input parameters which should be only bool type.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
176
-
177
-
178
def check_file_or_directory_path(path, isdir=False):
    """
    Validate ``path`` with FileChecker.common_check().

    Args:
        path: the path to check.
        isdir: True to require a writable directory, False to require a readable file.

    Raises:
        Whatever FileChecker.common_check raises for an invalid path.
    """
    if isdir:
        expected_type, required_access = FileCheckConst.DIR, FileCheckConst.WRITE_ABLE
    else:
        expected_type, required_access = FileCheckConst.FILE, FileCheckConst.READ_ABLE
    FileChecker(path, expected_type, required_access).common_check()
193
-
194
-
195
def is_starts_with(string, prefix_list):
    """Return True if ``string`` begins with any entry of ``prefix_list``."""
    # str.startswith accepts a tuple and tries each candidate in order.
    return string.startswith(tuple(prefix_list))
197
-
198
-
199
- def _check_json(json_file_handle, file_name):
200
- tensor_line = json_file_handle.readline()
201
- if not tensor_line:
202
- logger.error("dump file {} have empty line!".format(file_name))
203
- raise CompareException(CompareException.INVALID_DUMP_FILE)
204
- json_file_handle.seek(0, 0)
205
-
206
-
207
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """Run _check_json on the npu, bench and stack handles (in that order), naming each by its path."""
    handles_with_keys = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handles_with_keys:
        _check_json(handle, input_param.get(path_key))
211
-
212
-
213
def check_file_size(input_file, max_size):
    """
    Ensure ``input_file`` can be stat'ed and is at most ``max_size`` bytes.

    Raises:
        CompareException(INVALID_FILE_ERROR): os.path.getsize failed, or the file is too large.
    """
    try:
        size_in_bytes = os.path.getsize(input_file)
    except OSError as os_error:
        logger.error('Failed to open "%s". %s' % (input_file, str(os_error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error
    if size_in_bytes > max_size:
        logger.error('The size (%d) of %s exceeds (%d) bytes, tools not support.'
                     % (size_in_bytes, input_file, max_size))
        raise CompareException(CompareException.INVALID_FILE_ERROR)
223
-
224
-
225
def check_file_not_exists(file_path):
    """Delegate to remove_path when ``file_path`` exists or is a symlink (even a dangling one)."""
    if os.path.islink(file_path) or os.path.exists(file_path):
        remove_path(file_path)
228
-
229
-
230
def check_regex_prefix_format_valid(prefix):
    """
    Validate the format of a regex prefix.

    Args:
        prefix (str): the prefix string to validate.

    Raises:
        ValueError: ``prefix`` is longer than Const.REGEX_PREFIX_MAX_LENGTH characters,
            or does not match Const.REGEX_PREFIX_PATTERN (via re.match).
    """
    max_length = Const.REGEX_PREFIX_MAX_LENGTH
    prefix_length = len(prefix)
    if prefix_length > max_length:
        raise ValueError(f"Maximum length of prefix is {max_length}, while current length is {prefix_length}")
    pattern = Const.REGEX_PREFIX_PATTERN
    if re.match(pattern, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {pattern}")
249
-
250
-
251
def remove_path(path):
    """
    Delete ``path``: files and symlinks (including dangling ones) via os.remove,
    anything else via shutil.rmtree. A missing path is ignored.

    Raises:
        CompareException(INVALID_PATH_ERROR): deletion failed with PermissionError.
    """
    # lexists (unlike exists) is True for a dangling symlink, so such links are removed
    # instead of silently skipped; check_file_not_exists forwards symlinks here on purpose.
    if not os.path.lexists(path):
        return
    try:
        if os.path.islink(path) or os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)
    except PermissionError as err:
        logger.error("Failed to delete {}. Please check the permission.".format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from err
262
-
263
-
264
def move_file(src_path, dst_path):
    """
    Move ``src_path`` to ``dst_path`` after validating both paths, then apply
    FileCheckConst.DATA_FILE_AUTHORITY permissions to the destination.

    Raises:
        RuntimeError: shutil.move failed (the original error is chained).
    """
    check_file_or_directory_path(src_path)
    check_path_before_create(dst_path)
    try:
        shutil.move(src_path, dst_path)
    except Exception as e:
        failure = f"move file {src_path} to {dst_path} failed"
        logger.error(failure)
        raise RuntimeError(failure) from e
    change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY)
273
-
274
-
275
def get_dump_data_path(dump_dir):
    """
    Function Description:
        walk dump_dir top-down and locate the directory holding dump data
    Parameter:
        dump_dir: dump data directory (validated as a writable directory first)
    Return Value:
        (first walked directory containing files, True), or
        (last walked directory, False) when no directory contains files
    """
    check_file_or_directory_path(dump_dir, True)
    data_dir = None
    for walked_dir, _, file_names in os.walk(dump_dir):
        data_dir = walked_dir
        if file_names:
            return data_dir, True
    return data_dir, False
295
-
296
-
297
def create_directory(dir_path):
    """
    Function Description:
        create dir_path (os.makedirs with mode=0o700) unless it already exists
    Parameter:
        dir_path: directory path
    Exception Description:
        CompareException(INVALID_PATH_ERROR) when os.makedirs raises OSError
    """
    if os.path.exists(dir_path):
        return
    check_path_before_create(dir_path)
    try:
        os.makedirs(dir_path, mode=0o700)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
314
-
315
-
316
def execute_command(cmd):
    """
    Function Description:
        run the given command without a shell, printing each non-empty line of its
        combined stdout/stderr
    Parameter:
        cmd: command passed to subprocess.Popen (presumably an argument list)
    Exception Description:
        CompareException(INVALID_DATA_ERROR) when the command exits with a non-zero code
    """
    logger.info('Execute command:%s' % cmd)
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        # Read until EOF rather than until poll() reports exit, so output still buffered
        # in the pipe when the process finishes is not lost. The with-block closes the pipe.
        for raw_line in process.stdout:
            line = raw_line.strip()
            if line:
                # Decode so the output is printed as text rather than as a b'...' repr.
                print(line.decode(errors="replace"))
        process.wait()
    if process.returncode != 0:
        # Joining a plain string would insert a space between every character.
        command_text = cmd if isinstance(cmd, str) else " ".join(cmd)
        logger.error('Failed to execute command:%s' % command_text)
        raise CompareException(CompareException.INVALID_DATA_ERROR)
335
-
336
-
337
def parse_value_by_comma(value):
    """
    Parse a Const.COMMA-separated list of non-negative integers or -1, like '1,2,4,8'.

    Whitespace around each item is ignored.

    Raises:
        CompareException(INVALID_PARAM_ERROR): an item is neither a decimal integer nor '-1'.
    """
    value_list = []
    for value_str in value.split(Const.COMMA):
        value_str = value_str.strip()
        # isdecimal rather than isdigit: isdigit also accepts characters such as '²'
        # that int() rejects, which surfaced as an unhandled ValueError.
        if value_str.isdecimal() or value_str == '-1':
            value_list.append(int(value_str))
        else:
            logger.error("please check your input shape.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return value_list
351
-
352
-
353
def get_data_len_by_shape(shape):
    """
    Return the product of the dimensions in ``shape`` (1 for an empty shape).

    If a dimension equal to -1 is reached, log an error and return -1 instead.
    """
    total = 1
    for dim in shape:
        if dim == -1:
            logger.error("please check your input shape, one dim in shape is -1.")
            return -1
        total *= dim
    return total
361
-
362
-
363
def add_time_as_suffix(name):
    """Return '<name>_<YYYYmmddHHMMSS>.csv' stamped with the current local time."""
    timestamp = time.strftime("%Y%m%d%H%M%S")
    return f"{name}_{timestamp}.csv"
365
-
366
-
367
def add_time_with_xlsx(name):
    """Return '<name>_<YYYYmmddHHMMSS>.xlsx' stamped with the current local time."""
    timestamp = time.strftime("%Y%m%d%H%M%S")
    return f"{name}_{timestamp}.xlsx"
369
-
370
-
371
def get_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y%m%d_%H%M%S")
373
-
374
-
375
def format_value(value):
    """Round ``value`` to 12 decimal places by formatting it as text and parsing it back."""
    rounded_text = f"{value:.12f}"
    return float(rounded_text)
377
-
378
-
379
def check_seed_all(seed, mode):
    """Validate seed_all arguments.

    ``seed`` must be an int within [0, Const.MAX_SEED_VALUE] and ``mode`` must be a bool.
    Any violation is logged and raised as CompareException(INVALID_PARAM_ERROR).
    """
    if not isinstance(seed, int):
        logger.error("Seed must be integer.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
390
-
391
-
392
def get_process_rank(model):
    """
    Infer the rank id from the device of the model's first parameter.

    Returns:
        (local_device.index, True) when that parameter is not on the CPU;
        (0, False) when the model has no parameters or its first parameter is on the CPU.
    """
    logger.info("Rank id is not provided. Trying to get the rank id of the model.")
    try:
        first_param = next(model.parameters())
    except StopIteration:
        logger.warning('There is no parameter in the model. Fail to get rank id.')
        return 0, False
    local_device = first_param.device
    if local_device.type != 'cpu':
        return local_device.index, True
    logger.warning("Warning: the debugger is unable to get the rank id. "
                   "This may cause the dumpped data to be corrupted in the "
                   "case of distributed training. (You may ignore this if you are using only one card.) "
                   "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
    return 0, False
407
-
408
-
409
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
    """
    Render ``compare_script.template`` (located next to this module) into ``compare_data.py``
    in the directory of ``pkl_file_path``.

    The template is %-formatted with (pkl_file_path, dump_path, is_api_stack). is_api_stack is
    the string "True" when dump_switch_mode == Const.API_STACK and "False" otherwise.
    An OSError while opening or writing is logged and not re-raised.
    """
    template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
    pkl_dir = os.path.dirname(pkl_file_path)
    compare_script_path = os.path.join(pkl_dir, "compare_data.py")
    is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"

    try:
        with FileOpen(template_path, 'r') as ftemp, \
                os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
            code_temp = ftemp.read()
            fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
        # Nothing was generated; do not fall through to the success message below.
        return

    logger.info(f"Generate compare script successfully which is {compare_script_path}.")
424
-
425
-
426
def check_file_valid(file_path):
    """
    Validate a file path before it is used.

    Raises CompareException(INVALID_PATH_ERROR) when any of the following holds:
    - the path is a soft link;
    - the real path is longer than Const.DIRECTORY_LENGTH, or the base name is longer
      than Const.FILE_NAME_LENGTH;
    - the real path does not match Const.FILE_PATTERN;
    - the path is a regular file that is too large (Const.PKL_SUFFIX files above
      Const.ONE_GB, Const.NUMPY_SUFFIX files above Const.TEN_GB).
    """
    if os.path.islink(file_path):
        logger.error('The file path {} is a soft link.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    real_path = os.path.realpath(file_path)
    if len(real_path) > Const.DIRECTORY_LENGTH or \
            len(os.path.basename(file_path)) > Const.FILE_NAME_LENGTH:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not re.match(Const.FILE_PATTERN, real_path):
        logger.error('The file path {} contains special characters.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    # Size limits only apply to regular files.
    if not os.path.isfile(file_path):
        return
    file_size = os.path.getsize(file_path)
    if file_path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
        logger.error('The file {} size is greater than 1GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
    if file_path.endswith(Const.NUMPY_SUFFIX) and file_size > Const.TEN_GB:
        logger.error('The file {} size is greater than 10GB.'.format(file_path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
448
-
449
-
450
def check_path_before_create(path):
    """
    Validate a path that is about to be created.

    Raises CompareException(INVALID_PATH_ERROR) if the real path is longer than
    Const.DIRECTORY_LENGTH, the base name is longer than Const.FILE_NAME_LENGTH, or the
    real path does not match Const.FILE_PATTERN.
    """
    resolved = os.path.realpath(path)
    if len(resolved) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > Const.FILE_NAME_LENGTH:
        logger.error('The file path length exceeds limit.')
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    if not re.match(Const.FILE_PATTERN, resolved):
        logger.error('The file path {} contains special characters.'.format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR)
459
-
460
-
461
def check_inplace_op(prefix):
    """Tell whether ``prefix`` names a ``Distributed.<op>.<n>`` api whose op is in Const.INPLACE_LIST.

    Prefixes longer than Const.DISTRIBUTED_PREFIX_LENGTH are rejected up front (False).
    """
    if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
        return False
    # The first match's group equals the first element re.findall would return.
    found = re.search(r"Distributed\.(.+?)\.\d", prefix)
    op_name = found.group(1) if found else None
    return op_name in Const.INPLACE_LIST
467
-
468
-
469
def md5_find(data):
    """Return True if any api entry in the dump ``data`` carries an 'md5' field.

    ``data`` maps api names to per-api dicts. A list-valued field is searched element by
    element (falsy elements are skipped). Any other field is tested with ``'md5' in value``.
    """
    for op_info in data.values():
        for field_value in op_info.values():
            if isinstance(field_value, list):
                if any(detail and 'md5' in detail for detail in field_value):
                    return True
            elif 'md5' in field_value:
                return True
    return False
479
-
480
-
481
def task_dumppath_get(input_param):
    """
    Determine the compare mode from the npu/bench dump json files.

    As a side effect, this sets input_param['npu_dump_data_dir'] and
    input_param['bench_dump_data_dir'] to the Const.DUMP_TENSOR_DATA directory next to
    each json file.

    Returns:
        (summary_compare, md5_compare): (False, False) for Const.TENSOR dumps. For
        Const.STATISTICS dumps, md5_compare is whether md5 values are present and
        summary_compare is its negation.
    Raises:
        CompareException(INVALID_PATH_ERROR) if either json path is missing or empty.
        CompareException(INVALID_TASK_ERROR) if the two tasks differ, or the task is
        neither tensor nor statistics.
    """
    npu_path = input_param.get("npu_json_path")
    bench_path = input_param.get("bench_json_path")
    if not (npu_path and bench_path):
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(npu_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error("Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error("Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    return summary_compare, md5_compare
509
-
510
-
511
def get_header_index(header_name, summary_compare=False):
    """Return the column index of ``header_name`` in the summary or full compare result header.

    Raises CompareException(INVALID_PARAM_ERROR) if the header name is not present.
    """
    header = (CompareConst.SUMMARY_COMPARE_RESULT_HEADER if summary_compare
              else CompareConst.COMPARE_RESULT_HEADER)
    if header_name not in header:
        logger.error(f"{header_name} not in data name")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return header.index(header_name)
520
-
521
-
522
def convert_tuple(data):
    """Return ``data`` unchanged if it is a tuple; otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)
524
-
525
-
526
def write_csv(data, filepath, mode="a+"):
    """Write ``data`` rows to a UTF-8 (with BOM) CSV file; appends by default.

    When the file did not exist beforehand, its permissions are set to
    FileCheckConst.DATA_FILE_AUTHORITY after writing.
    """
    is_new_file = not os.path.exists(filepath)
    with FileOpen(filepath, mode, encoding='utf-8-sig') as csv_file:
        csv.writer(csv_file).writerows(data)
    if is_new_file:
        change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
533
-
534
-
535
def load_npy(filepath):
    """Check ``filepath`` and return the array loaded from it with np.load.

    Raises:
        RuntimeError: if np.load fails (chained to the original error).
    """
    check_file_or_directory_path(filepath)
    try:
        return np.load(filepath)
    except Exception as e:
        logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
        raise RuntimeError(f"Load numpy file {filepath} failed.") from e
543
-
544
-
545
def save_npy(data, filepath):
    """Save ``data`` with np.save to the resolved ``filepath`` and restrict the file's permissions.

    np.save appends '.npy' to file names that lack it. The suffix is therefore added
    up front, so the path checks and the permission change apply to the file that is
    actually written.

    Raises:
        RuntimeError: if np.save fails (chained to the original error).
    """
    filepath = os.path.realpath(filepath)
    if not filepath.endswith('.npy'):
        filepath = filepath + '.npy'
    check_path_before_create(filepath)
    try:
        np.save(filepath, data)
    except Exception as e:
        logger.error(f"The numpy file failed to save. Please check the path: {filepath}.")
        raise RuntimeError(f"Save numpy file {filepath} failed.") from e
    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
554
-
555
def save_npy_to_txt(self, data, dst_file='', align=0):
    """
    Save a numpy array as space-separated text with np.savetxt (format '%g').

    The array is flattened and written ``align`` values per row. If align == 0, the row
    width is the last dim of ``data`` (1 for a 0-d array). Otherwise, the flattened data
    is zero-padded up to a multiple of ``align``. An existing ``dst_file`` is left
    untouched. ``self`` must provide a ``log`` object with ``info``/``error`` methods.
    """
    if os.path.exists(dst_file):
        self.log.info("Dst file %s exists, will not save new one.", dst_file)
        return
    shape = data.shape
    data = data.flatten()
    if align == 0:
        align = 1 if len(shape) == 0 else shape[-1]
    elif data.size % align != 0:
        pad_array = np.zeros((align - data.size % align,))
        data = np.append(data, pad_array)
    check_path_before_create(dst_file)
    try:
        np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
    except Exception as e:
        # Pass both values as logging arguments. The previous "... %s ... %s" % (str(e))
        # supplied one value for two placeholders and raised TypeError, hiding the real error.
        self.log.error("An unexpected error occurred: %s when savetxt to %s", str(e), dst_file)
    # savetxt may have failed before creating the file; only change the mode of a file that exists.
    if os.path.exists(dst_file):
        change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
572
-
573
def get_json_contents(file_path):
    """Read ``file_path`` and return its JSON content, which must be a JSON object (dict).

    Raises:
        CompareException(INVALID_FILE_ERROR): if the content is not valid JSON or is not a dict.
    """
    raw_bytes = get_file_content_bytes(file_path)
    try:
        json_obj = json.loads(raw_bytes)
    except ValueError as error:
        logger.error('Failed to load json.')
        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
    if isinstance(json_obj, dict):
        return json_obj
    logger.error('Json file content is not a dictionary!')
    raise CompareException(CompareException.INVALID_FILE_ERROR)
584
-
585
-
586
def get_file_content_bytes(file):
    """Read and return the whole content of ``file`` as bytes (opened via FileOpen in 'rb' mode)."""
    with FileOpen(file, 'rb') as file_handle:
        content = file_handle.read()
    return content
589
-
590
-
591
def load_yaml(yaml_path):
    """Check ``yaml_path`` with FileChecker (FILE / READ_ABLE / YAML_SUFFIX) and return its
    content parsed with yaml.safe_load.

    Raises:
        RuntimeError: if opening or parsing fails (chained to the original error).
    """
    checked_path = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
                               FileCheckConst.YAML_SUFFIX).common_check()
    try:
        with FileOpen(checked_path, "r") as yaml_file:
            return yaml.safe_load(yaml_file)
    except Exception as e:
        logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
        raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
601
-
602
-
603
def save_workbook(workbook, file_path):
    """
    Save the workbook to the given file path and restrict the file's permissions.

    workbook: the workbook object to save (must provide a ``save(path)`` method)
    file_path: destination file path; it is resolved with os.path.realpath and checked first

    Raises:
        CompareException(WRITE_FILE_ERROR): if workbook.save fails.
    """
    target_path = os.path.realpath(file_path)
    check_path_before_create(target_path)
    try:
        workbook.save(target_path)
    except Exception as e:
        logger.error(f'Save result file "{os.path.basename(target_path)}" failed')
        raise CompareException(CompareException.WRITE_FILE_ERROR) from e
    change_mode(target_path, FileCheckConst.DATA_FILE_AUTHORITY)
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import collections
18
+ import os
19
+ import re
20
+ import subprocess
21
+ import time
22
+ import json
23
+ from datetime import datetime, timezone
24
+
25
+ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path)
26
+ from msprobe.core.common.const import Const, CompareConst
27
+ from msprobe.core.common.log import logger
28
+
29
+
30
# Lightweight (type, index) record describing a device.
device = collections.namedtuple('device', 'type index')
# Known dump-mode prefixes.
prefixes = "api_stack list range acl".split()
32
+
33
+
34
class CompareException(Exception):
    """
    Class for Accuracy Compare Exception.

    ``code`` is one of the integer constants below. ``error_info`` is an optional
    human-readable message, which is also what ``str()`` returns.
    """
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20
    DETACH_ERROR = 21

    def __init__(self, code, error_info: str = ""):
        # Forward both arguments to Exception so that ``self.args`` is populated.
        # Unpickling rebuilds an exception as ``cls(*self.args)``. With empty args that
        # call fails because ``code`` is required, so the exception could not cross a
        # process boundary (e.g. when raised inside a multiprocessing worker).
        super().__init__(code, error_info)
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
69
+
70
+
71
class DumpException(CompareException):
    """Exception for dump errors; inherits CompareException's error codes and (code, error_info) constructor."""
73
+
74
+
75
def check_mode_valid(mode, scope=None, api_list=None):
    """
    Validate a dump ``mode`` together with its ``scope`` / ``api_list`` arguments.

    Raises:
        ValueError: if scope or api_list is not a list, or if they do not fit the mode.
            RANGE needs exactly 2 scope items, LIST a non-empty scope, STACK at most 2
            items, ACL exactly 1 item, and API_LIST a non-empty api_list.
        CompareException(INVALID_DUMP_MODE): if mode is not in Const.DUMP_MODE.
    """
    if scope is None:
        scope = []
    if api_list is None:
        api_list = []
    if not isinstance(scope, list):
        raise ValueError("scope param set invalid, it's must be a list.")
    if not isinstance(api_list, list):
        raise ValueError("api_list param set invalid, it's must be a list.")
    # Each entry returns the ValueError to raise for that mode, or None when the arguments fit.
    mode_check = {
        Const.ALL: lambda: None,
        Const.RANGE: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None,
        Const.LIST: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None,
        Const.STACK: lambda: ValueError(
            "set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None,
        Const.ACL: lambda: ValueError(
            "set_dump_switch, scope param set invalid, only one api name is supported in acl mode.")
        if len(scope) != 1 else None,
        Const.API_LIST: lambda: ValueError(
            "Current dump mode is 'api_list', but the content of api_list parameter is empty or invalid.")
        if len(api_list) < 1 else None,
        Const.API_STACK: lambda: None,
    }
    if mode not in Const.DUMP_MODE:
        msg = "Current mode '%s' is not supported. Please use the field in %s" % \
              (mode, Const.DUMP_MODE)
        raise CompareException(CompareException.INVALID_DUMP_MODE, msg)

    # Look up and run the checker once. A valid mode without a checker needs no extra
    # validation; calling mode_check.get(mode)() directly raised TypeError on None here.
    checker = mode_check.get(mode)
    error = checker() if checker is not None else None
    if error is not None:
        raise error
100
+
101
+
102
def check_switch_valid(switch):
    """Accept only "ON" or "OFF"; anything else is logged and raised as CompareException(INVALID_PARAM_ERROR)."""
    if switch in ("ON", "OFF"):
        return
    logger.error("Please set switch with 'ON' or 'OFF'.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
106
+
107
+
108
def check_dump_mode_valid(dump_mode):
    """
    Normalize the dump_mode setting into a list of modes.

    Accepted items are 'all', 'forward', 'backward', 'input' and 'output'. A non-list value
    is wrapped into a one-element list (with a warning). If neither 'input' nor 'output' is
    given, both are added; the same applies to 'forward'/'backward'. When 'all' is given,
    or all four concrete modes end up present, ["forward", "backward", "input", "output"]
    is returned.

    The caller's list is never modified; a new list is always returned.

    Raises:
        ValueError: if any item is not one of the accepted modes.
    """
    if isinstance(dump_mode, list):
        # Work on a copy: the extend() calls below would otherwise mutate the caller's list.
        dump_mode = list(dump_mode)
    else:
        logger.warning("Please set dump_mode as a list.")
        dump_mode = [dump_mode]
    if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
        raise ValueError("Please set dump_mode as a list containing one or more of the following: "
                         "'all', 'forward', 'backward', 'input', 'output'.")
    if 'input' not in dump_mode and 'output' not in dump_mode:
        dump_mode.extend(['input', 'output'])
    if 'forward' not in dump_mode and 'backward' not in dump_mode:
        dump_mode.extend(['forward', 'backward'])
    if 'all' in dump_mode or {"forward", "backward", "input", "output"}.issubset(dump_mode):
        return ["forward", "backward", "input", "output"]
    return dump_mode
121
+
122
+
123
def check_summary_mode_valid(summary_mode):
    """Raise CompareException(INVALID_SUMMARY_MODE) unless ``summary_mode`` is in Const.SUMMARY_MODE."""
    if summary_mode in Const.SUMMARY_MODE:
        return
    raise CompareException(CompareException.INVALID_SUMMARY_MODE, "The summary_mode is not valid")
127
+
128
+
129
def check_summary_only_valid(summary_only):
    """Return ``summary_only`` unchanged if it is a bool; otherwise log and raise CompareException(INVALID_PARAM_ERROR)."""
    if isinstance(summary_only, bool):
        return summary_only
    logger.error("Params summary_only only support True or False.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
134
+
135
+
136
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
    """
    Validate the compare inputs before running a comparison.

    Checks the npu/bench/stack json paths and the output directory. When neither
    summary_compare nor md5_compare is set, the npu/bench dump data directories are also
    checked. Finally, each json file is opened and must not be empty.

    Raises:
        CompareException(INVALID_PARAM_ERROR) if input_param is not a dict or output_path is not a str.
    """
    if not isinstance(input_param, dict) or not isinstance(output_path, str):
        logger.error("Invalid input parameters")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    for json_key in ("npu_json_path", "bench_json_path", "stack_json_path"):
        check_file_or_directory_path(input_param.get(json_key), False)
    # Real tensor data is only needed for a full (non-summary, non-md5) comparison.
    if not (summary_compare or md5_compare):
        for dir_key in ("npu_dump_data_dir", "bench_dump_data_dir"):
            check_file_or_directory_path(input_param.get(dir_key), True)
    check_file_or_directory_path(output_path, True)

    with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
            FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \
            FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
        check_json_file(input_param, npu_json, bench_json, stack_json)
153
+
154
+
155
+
156
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """Ensure stack_mode, auto_analyze and fuzzy_match are all bools.

    Raises CompareException(INVALID_PARAM_ERROR) if any of them is not a bool.
    """
    flags = (stack_mode, auto_analyze, fuzzy_match)
    if all(isinstance(flag, bool) for flag in flags):
        return
    logger.error("Invalid input parameters which should be only bool type.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
160
+
161
+
162
def is_starts_with(string, prefix_list):
    """Return True if ``string`` starts with at least one prefix from ``prefix_list``."""
    for prefix in prefix_list:
        if string.startswith(prefix):
            return True
    return False
164
+
165
+
166
+ def _check_json(json_file_handle, file_name):
167
+ tensor_line = json_file_handle.readline()
168
+ if not tensor_line:
169
+ logger.error("dump file {} have empty line!".format(file_name))
170
+ raise CompareException(CompareException.INVALID_DUMP_FILE)
171
+ json_file_handle.seek(0, 0)
172
+
173
+
174
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """Run _check_json on the npu, bench and stack json handles, in that order.

    The path stored in ``input_param`` under the matching key is used in error messages.
    """
    handles_and_keys = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handles_and_keys:
        _check_json(handle, input_param.get(path_key))
178
+
179
+
180
def check_regex_prefix_format_valid(prefix):
    """
    Validate the format of a regex prefix.

    Args:
        prefix (str): The prefix string to validate.

    Returns:
        None

    Raises:
        ValueError: if the prefix is longer than Const.REGEX_PREFIX_MAX_LENGTH characters, or if
            it does not match the pattern Const.REGEX_PREFIX_PATTERN.
    """
    prefix_len = len(prefix)
    if prefix_len > Const.REGEX_PREFIX_MAX_LENGTH:
        raise ValueError(f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, while current length "
                         f"is {prefix_len}")
    if re.match(Const.REGEX_PREFIX_PATTERN, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
199
+
200
+
201
def get_dump_data_path(dump_dir):
    """
    Function Description:
        Walk ``dump_dir`` top-down and locate the first directory that contains files.
    Parameter:
        dump_dir: dump data directory (validated with check_file_or_directory_path first)
    Return Value:
        (directory, True) for the first walked directory that holds at least one file.
        Otherwise (last walked directory, False).
    """
    check_file_or_directory_path(dump_dir, True)
    for current_dir, _, file_names in os.walk(dump_dir):
        if file_names:
            return current_dir, True
    return current_dir, False
221
+
222
+
223
def execute_command(cmd):
    """
    Function Description:
        Run the given command and echo its combined stdout/stderr line by line.
    Parameter:
        cmd: command as an argument list (executed with shell=False)
    Exception Description:
        Raises CompareException(INVALID_DATA_ERROR) when the command exits with a non-zero code.
    """
    logger.info('Execute command:%s' % cmd)
    process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    # Read the pipe until EOF instead of looping on poll(). The poll() loop stops as soon
    # as the child exits, which drops any output still buffered in the pipe. Using Popen
    # as a context manager closes the pipe and waits for the child, so returncode is set.
    with process:
        for raw_line in process.stdout:
            line = raw_line.strip()
            if line:
                # The pipe yields bytes; decode so the text is printed, not a b'...' repr.
                print(line.decode('utf-8', errors='replace'))
    if process.returncode != 0:
        logger.error('Failed to execute command:%s' % " ".join(cmd))
        raise CompareException(CompareException.INVALID_DATA_ERROR)
242
+
243
+
244
def parse_value_by_comma(value):
    """
    Parse a comma-separated list of integers, like '1,2,4,8'.

    Each item (surrounding whitespace ignored) must be a non-negative decimal integer
    or the literal '-1'. Anything else is logged and rejected with
    CompareException(INVALID_PARAM_ERROR).
    """
    value_list = []
    for value_str in value.split(Const.COMMA):
        value_str = value_str.strip()
        # Use isdecimal() rather than isdigit(). isdigit() also accepts characters such
        # as superscripts ('²') that int() cannot parse, which would escape as a bare
        # ValueError instead of the CompareException raised below.
        if value_str.isdecimal() or value_str == '-1':
            value_list.append(int(value_str))
        else:
            logger.error("please check your input shape.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return value_list
258
+
259
+
260
def add_time_as_suffix(name):
    """Append a local-time ``YYYYmmddHHMMSS`` stamp and a ``.csv`` extension to ``name``."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{stamp}.csv"
262
+
263
+
264
def add_time_with_xlsx(name):
    """Append a local-time ``YYYYmmddHHMMSS`` stamp and an ``.xlsx`` extension to ``name``."""
    stamp = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
    return f"{name}_{stamp}.xlsx"
266
+
267
+
268
def get_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y%m%d_%H%M%S")
270
+
271
+
272
def format_value(value):
    """Round ``value`` to 12 decimal places by formatting it as text and parsing it back."""
    rounded_text = f"{value:.12f}"
    return float(rounded_text)
274
+
275
+
276
def check_seed_all(seed, mode):
    """Validate seed_all arguments.

    ``seed`` must be an int within [0, Const.MAX_SEED_VALUE] and ``mode`` must be a bool.
    Any violation is logged and raised as CompareException(INVALID_PARAM_ERROR).
    """
    if not isinstance(seed, int):
        logger.error("Seed must be integer.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not 0 <= seed <= Const.MAX_SEED_VALUE:
        logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    if not isinstance(mode, bool):
        logger.error("seed_all mode must be bool.")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
287
+
288
+
289
def get_process_rank(model):
    """
    Infer the rank id from the device of the model's first parameter.

    Returns:
        (local_device.index, True) when that parameter is not on the CPU;
        (0, False) when the model has no parameters or its first parameter is on the CPU.
    """
    logger.info("Rank id is not provided. Trying to get the rank id of the model.")
    try:
        first_param = next(model.parameters())
    except StopIteration:
        logger.warning('There is no parameter in the model. Fail to get rank id.')
        return 0, False
    local_device = first_param.device
    if local_device.type != 'cpu':
        return local_device.index, True
    logger.warning("Warning: the debugger is unable to get the rank id. "
                   "This may cause the dumpped data to be corrupted in the "
                   "case of distributed training. (You may ignore this if you are using only one card.) "
                   "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
    return 0, False
304
+
305
+
306
def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
    """
    Render ``compare_script.template`` (located next to this module) into ``compare_data.py``
    in the directory of ``pkl_file_path``.

    The template is %-formatted with (pkl_file_path, dump_path, is_api_stack). is_api_stack is
    the string "True" when dump_switch_mode == Const.API_STACK and "False" otherwise.
    An OSError while opening or writing is logged and not re-raised.
    """
    template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
    pkl_dir = os.path.dirname(pkl_file_path)
    compare_script_path = os.path.join(pkl_dir, "compare_data.py")
    is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"

    try:
        with FileOpen(template_path, 'r') as ftemp, \
                os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
            code_temp = ftemp.read()
            fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
    except OSError:
        logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
        # Nothing was generated; do not fall through to the success message below.
        return

    logger.info(f"Generate compare script successfully which is {compare_script_path}.")
321
+
322
+
323
def check_inplace_op(prefix):
    """Tell whether ``prefix`` names a ``Distributed.<op>.<n>`` api whose op is in Const.INPLACE_LIST.

    Prefixes longer than Const.DISTRIBUTED_PREFIX_LENGTH are rejected up front (False).
    """
    if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
        return False
    # The first match's group equals the first element re.findall would return.
    found = re.search(r"Distributed\.(.+?)\.\d", prefix)
    op_name = found.group(1) if found else None
    return op_name in Const.INPLACE_LIST
329
+
330
+
331
def md5_find(data):
    """Return True if any api entry in the dump ``data`` carries an 'md5' field.

    ``data`` maps api names to per-api dicts. A list-valued field is searched element by
    element (falsy elements are skipped). Any other field is tested with ``'md5' in value``.
    """
    for op_info in data.values():
        for field_value in op_info.values():
            if isinstance(field_value, list):
                if any(detail and 'md5' in detail for detail in field_value):
                    return True
            elif 'md5' in field_value:
                return True
    return False
341
+
342
+
343
def task_dumppath_get(input_param):
    """
    Determine the compare mode from the npu/bench dump json files.

    As a side effect, this sets input_param['npu_dump_data_dir'] and
    input_param['bench_dump_data_dir'] to the Const.DUMP_TENSOR_DATA directory next to
    each json file.

    Returns:
        (summary_compare, md5_compare): (False, False) for Const.TENSOR dumps. For
        Const.STATISTICS dumps, md5_compare is whether md5 values are present and
        summary_compare is its negation.
    Raises:
        CompareException(INVALID_PATH_ERROR) if either json path is missing or empty.
        CompareException(INVALID_TASK_ERROR) if the two tasks differ, or the task is
        neither tensor nor statistics.
    """
    npu_path = input_param.get("npu_json_path")
    bench_path = input_param.get("bench_json_path")
    if not (npu_path and bench_path):
        logger.error("Please check the json path is valid.")
        raise CompareException(CompareException.INVALID_PATH_ERROR)

    with FileOpen(npu_path, 'r') as npu_f:
        npu_json_data = json.load(npu_f)
    with FileOpen(bench_path, 'r') as bench_f:
        bench_json_data = json.load(bench_f)

    task = npu_json_data['task']
    if task != bench_json_data['task']:
        logger.error("Please check the dump task is consistent.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    if task == Const.TENSOR:
        summary_compare, md5_compare = False, False
    elif task == Const.STATISTICS:
        md5_compare = md5_find(npu_json_data['data'])
        summary_compare = not md5_compare
    else:
        logger.error("Compare is not required for overflow_check or free_benchmark.")
        raise CompareException(CompareException.INVALID_TASK_ERROR)

    input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
    input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
    return summary_compare, md5_compare
371
+
372
+
373
def get_header_index(header_name, summary_compare=False):
    """Return the column index of ``header_name`` in the summary or full compare result header.

    Raises CompareException(INVALID_PARAM_ERROR) if the header name is not present.
    """
    header = (CompareConst.SUMMARY_COMPARE_RESULT_HEADER if summary_compare
              else CompareConst.COMPARE_RESULT_HEADER)
    if header_name not in header:
        logger.error(f"{header_name} not in data name")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)
    return header.index(header_name)
382
+
383
+
384
def convert_tuple(data):
    """Return ``data`` unchanged if it is a tuple; otherwise wrap it in a 1-tuple."""
    if isinstance(data, tuple):
        return data
    return (data,)