mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,616 +1,371 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import collections
18
- import os
19
- import re
20
- import shutil
21
- import subprocess
22
- import time
23
- import json
24
- import csv
25
- from datetime import datetime, timezone
26
- from pathlib import Path
27
- import yaml
28
- import numpy as np
29
-
30
- from msprobe.core.common.file_check import FileOpen, FileChecker, change_mode
31
- from msprobe.core.common.const import Const, FileCheckConst, CompareConst
32
- from msprobe.core.common.log import logger
33
-
34
-
35
# Lightweight immutable record with fields ``type`` and ``index``
# (presumably a device type string and device ordinal -- confirm with callers).
device = collections.namedtuple('device', 'type index')

# Name prefixes tested by is_starts_with(); their exact role is defined by callers.
prefixes = [
    'api_stack',
    'list',
    'range',
    'acl',
]
37
-
38
-
39
class CompareException(Exception):
    """
    Exception raised by the accuracy-compare / dump utilities.

    Attributes:
        code: one of the integer error constants defined on this class.
        error_info: optional human-readable message; also the ``str()`` form.
    """
    NONE_ERROR = 0
    INVALID_PATH_ERROR = 1
    OPEN_FILE_ERROR = 2
    CLOSE_FILE_ERROR = 3
    READ_FILE_ERROR = 4
    WRITE_FILE_ERROR = 5
    INVALID_FILE_ERROR = 6
    PERMISSION_ERROR = 7
    INDEX_OUT_OF_BOUNDS_ERROR = 8
    NO_DUMP_FILE_ERROR = 9
    INVALID_DATA_ERROR = 10
    INVALID_PARAM_ERROR = 11
    INVALID_DUMP_RATIO = 12
    INVALID_DUMP_FILE = 13
    UNKNOWN_ERROR = 14
    INVALID_DUMP_MODE = 15
    PARSE_FILE_ERROR = 16
    INVALID_COMPARE_MODE = 17
    OVER_SIZE_FILE_ERROR = 18
    INVALID_SUMMARY_MODE = 19
    INVALID_TASK_ERROR = 20

    def __init__(self, code, error_info: str = ""):
        # Forward the constructor arguments to Exception so ``self.args``
        # mirrors this signature. Unpickling an exception calls
        # ``cls(*self.args)``. With empty args that call would fail for the
        # missing ``code``, which breaks the exception when it crosses a
        # process boundary (e.g. multiprocessing workers).
        super().__init__(code, error_info)
        self.code = code
        self.error_info = error_info

    def __str__(self):
        return self.error_info
72
-
73
-
74
class DumpException(CompareException):
    """
    CompareException subclass with no additional behaviour.

    It exists so callers can catch it separately from other CompareExceptions
    (presumably dump-stage failures -- confirm with callers).
    """
76
-
77
-
78
def make_dump_path_if_not_exists(dump_path):
    """
    Ensure ``dump_path`` exists, creating it (and missing parents) with mode 0o750.

    If something already exists at ``dump_path`` but is not a directory, an
    error is logged and the function returns without raising.

    Raises:
        CompareException: INVALID_PATH_ERROR if the directory cannot be created.
    """
    if os.path.exists(dump_path):
        if not os.path.isdir(dump_path):
            logger.error('{} already exists and is not a directory.'.format(dump_path))
        return

    try:
        Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
89
-
90
-
91
def check_mode_valid(mode, scope=None, api_list=None):
    """
    Validate a dump ``mode`` together with its ``scope`` and ``api_list``.

    Args:
        mode: dump mode; must be a member of ``Const.DUMP_MODE``.
        scope: list of API names; its required length depends on ``mode``
            (RANGE: exactly 2, LIST: non-empty, STACK: at most 2, ACL: exactly 1).
        api_list: list of API names; must be non-empty in ``Const.API_LIST`` mode.

    Raises:
        ValueError: if ``scope``/``api_list`` is not a list, or does not meet
            the constraint for ``mode``.
        CompareException: INVALID_DUMP_MODE if ``mode`` is not in ``Const.DUMP_MODE``.
    """
    if scope is None:
        scope = []
    if api_list is None:
        api_list = []
    if not isinstance(scope, list):
        raise ValueError("scope param set invalid, it's must be a list.")
    if not isinstance(api_list, list):
        raise ValueError("api_list param set invalid, it's must be a list.")

    # mode -> (is_invalid, message). Key order matches the original rule table,
    # so any colliding Const values resolve the same way.
    mode_rules = {
        Const.ALL: (False, ""),
        Const.RANGE: (len(scope) != 2,
                      "set_dump_switch, scope param set invalid, it's must be [start, end]."),
        Const.LIST: (len(scope) == 0,
                     "set_dump_switch, scope param set invalid, it's should not be an empty list."),
        Const.STACK: (len(scope) > 2,
                      "set_dump_switch, scope param set invalid, it's must be [start, end] or []."),
        Const.ACL: (len(scope) != 1,
                    "set_dump_switch, scope param set invalid, only one api name is supported in acl mode."),
        Const.API_LIST: (len(api_list) < 1,
                         "Current dump mode is 'api_list', but the content of api_list parameter is empty or invalid."),
        Const.API_STACK: (False, ""),
    }
    if mode not in Const.DUMP_MODE:
        msg = "Current mode '%s' is not supported. Please use the field in %s" % \
              (mode, Const.DUMP_MODE)
        raise CompareException(CompareException.INVALID_DUMP_MODE, msg)

    # A supported mode without a dedicated rule needs no extra validation.
    # The old lookup crashed here with "'NoneType' object is not callable".
    is_invalid, message = mode_rules.get(mode, (False, ""))
    if is_invalid:
        raise ValueError(message)
116
-
117
-
118
def check_switch_valid(switch):
    """
    Accept only the literal switch values "ON" and "OFF".

    Raises:
        CompareException: INVALID_PARAM_ERROR for any other value.
    """
    allowed = ("ON", "OFF")
    if switch in allowed:
        return
    logger.error("Please set switch with 'ON' or 'OFF'.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
122
-
123
-
124
def check_dump_mode_valid(dump_mode):
    """
    Normalise a dump-mode selection into a list of stage/direction keywords.

    ``dump_mode`` is a list drawn from 'all', 'forward', 'backward', 'input'
    and 'output'. A non-list value logs a warning and is wrapped in a list.

    Defaults and shortcuts:
    - If neither 'input' nor 'output' is present, both are added.
    - If neither 'forward' nor 'backward' is present, both are added.
    - If 'all' is present, or all four keywords end up selected, the canonical
      ``["forward", "backward", "input", "output"]`` is returned.

    The caller's list is never modified; a new list is always returned.

    Raises:
        ValueError: if any entry is not one of the supported keywords.
    """
    if isinstance(dump_mode, list):
        # Copy before appending defaults. Extending the caller's list in place
        # would leak the defaults back into the caller's object.
        dump_mode = list(dump_mode)
    else:
        logger.warning("Please set dump_mode as a list.")
        dump_mode = [dump_mode]
    if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode):
        raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.")
    if 'input' not in dump_mode and 'output' not in dump_mode:
        dump_mode.extend(['input', 'output'])
    if 'forward' not in dump_mode and 'backward' not in dump_mode:
        dump_mode.extend(['forward', 'backward'])
    if 'all' in dump_mode or {"forward", "backward", "input", "output"}.issubset(dump_mode):
        return ["forward", "backward", "input", "output"]
    return dump_mode
137
-
138
-
139
def check_summary_mode_valid(summary_mode):
    """
    Require ``summary_mode`` to be a member of ``Const.SUMMARY_MODE``.

    Raises:
        CompareException: INVALID_SUMMARY_MODE otherwise.
    """
    if summary_mode in Const.SUMMARY_MODE:
        return
    msg = "The summary_mode is not valid"
    raise CompareException(CompareException.INVALID_SUMMARY_MODE, msg)
143
-
144
-
145
def check_summary_only_valid(summary_only):
    """
    Return ``summary_only`` unchanged after confirming it is a bool.

    Raises:
        CompareException: INVALID_PARAM_ERROR for non-bool values.
    """
    if isinstance(summary_only, bool):
        return summary_only
    logger.error("Params summary_only only support True or False.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
150
-
151
-
152
def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
    """
    Validate the inputs of a compare run.

    Args:
        input_param: dict with keys 'npu_json_path', 'bench_json_path',
            'stack_json_path' and, for full compares, 'npu_dump_data_dir' and
            'bench_dump_data_dir'.
        output_path: output directory (must be a str).
        summary_compare: when truthy, the dump-data directories are not checked.
        md5_compare: when truthy, the dump-data directories are not checked.

    Raises:
        CompareException: INVALID_PARAM_ERROR for wrongly typed arguments; the
            path and JSON checks called below raise their own errors.
    """
    if not isinstance(input_param, dict) or not isinstance(output_path, str):
        logger.error("Invalid input parameters")
        raise CompareException(CompareException.INVALID_PARAM_ERROR)

    npu_json_path = input_param.get("npu_json_path")
    bench_json_path = input_param.get("bench_json_path")
    stack_json_path = input_param.get("stack_json_path")
    for json_path in (npu_json_path, bench_json_path, stack_json_path):
        check_file_or_directory_path(json_path, False)

    # Raw dump-data directories are only required for a full (tensor) compare.
    if not (summary_compare or md5_compare):
        check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
        check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
    check_file_or_directory_path(output_path, True)

    with FileOpen(npu_json_path, "r") as npu_json, \
            FileOpen(bench_json_path, "r") as bench_json, \
            FileOpen(stack_json_path, "r") as stack_json:
        check_json_file(input_param, npu_json, bench_json, stack_json)
169
-
170
-
171
-
172
def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False):
    """
    Ensure the three compare configuration flags are all bools.

    Raises:
        CompareException: INVALID_PARAM_ERROR if any flag is not a bool.
    """
    flags = (stack_mode, auto_analyze, fuzzy_match)
    if all(isinstance(flag, bool) for flag in flags):
        return
    logger.error("Invalid input parameters which should be only bool type.")
    raise CompareException(CompareException.INVALID_PARAM_ERROR)
176
-
177
-
178
def check_file_or_directory_path(path, isdir=False):
    """
    Validate ``path`` with ``FileChecker.common_check``.

    Args:
        path: the path to check.
        isdir: when truthy, ``path`` is checked as a writable directory;
            otherwise as a readable file.

    Raises:
        Whatever ``FileChecker.common_check`` raises for an invalid path
        (FileChecker is imported from msprobe.core.common.file_check).
    """
    if isdir:
        path_kind, required_ability = FileCheckConst.DIR, FileCheckConst.WRITE_ABLE
    else:
        path_kind, required_ability = FileCheckConst.FILE, FileCheckConst.READ_ABLE
    FileChecker(path, path_kind, required_ability).common_check()
193
-
194
-
195
def is_starts_with(string, prefix_list):
    """Return True if ``string`` begins with any prefix in ``prefix_list``."""
    # str.startswith accepts a tuple and stops at the first matching prefix.
    return string.startswith(tuple(prefix_list))
197
-
198
-
199
- def _check_json(json_file_handle, file_name):
200
- tensor_line = json_file_handle.readline()
201
- if not tensor_line:
202
- logger.error("dump file {} have empty line!".format(file_name))
203
- raise CompareException(CompareException.INVALID_DUMP_FILE)
204
- json_file_handle.seek(0, 0)
205
-
206
-
207
def check_json_file(input_param, npu_json, bench_json, stack_json):
    """
    Run _check_json on the npu, bench and stack JSON handles, in that order.

    Each handle's file name for error messages comes from the matching
    '*_json_path' key of ``input_param``.
    """
    handles_with_keys = (
        (npu_json, "npu_json_path"),
        (bench_json, "bench_json_path"),
        (stack_json, "stack_json_path"),
    )
    for handle, path_key in handles_with_keys:
        _check_json(handle, input_param.get(path_key))
211
-
212
-
213
def check_file_size(input_file, max_size):
    """
    Ensure ``input_file`` is no larger than ``max_size`` bytes.

    Raises:
        CompareException: INVALID_FILE_ERROR if the size cannot be read
            (OSError) or exceeds ``max_size``.
    """
    try:
        size_in_bytes = os.path.getsize(input_file)
    except OSError as os_error:
        logger.error('Failed to open "%s". %s' % (input_file, str(os_error)))
        raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error

    if size_in_bytes > max_size:
        logger.error('The size (%d) of %s exceeds (%d) bytes, tools not support.'
                     % (size_in_bytes, input_file, max_size))
        raise CompareException(CompareException.INVALID_FILE_ERROR)
223
-
224
-
225
def check_file_not_exists(file_path):
    """
    Pass ``file_path`` to remove_path if anything is there.

    "Anything" includes a symlink whose target is missing, which
    os.path.exists alone would not detect.
    """
    occupied = os.path.exists(file_path) or os.path.islink(file_path)
    if occupied:
        remove_path(file_path)
228
-
229
-
230
def check_regex_prefix_format_valid(prefix):
    """
    Validate the format of a regex prefix.

    Args:
        prefix (str): the prefix string to validate.

    Raises:
        ValueError: if ``prefix`` is longer than ``Const.REGEX_PREFIX_MAX_LENGTH``,
            or if ``re.match`` of ``Const.REGEX_PREFIX_PATTERN`` fails
            (``re.match`` anchors only at the start of the string).
    """
    max_length = Const.REGEX_PREFIX_MAX_LENGTH
    actual_length = len(prefix)
    if actual_length > max_length:
        raise ValueError(f"Maximum length of prefix is {max_length}, while current length "
                         f"is {actual_length}")

    pattern = Const.REGEX_PREFIX_PATTERN
    if re.match(pattern, prefix) is None:
        raise ValueError(f"prefix contains invalid characters, prefix pattern {pattern}")
249
-
250
-
251
def remove_path(path):
    """
    Delete whatever is at ``path``: a file, a symlink or a directory tree.

    Does nothing when there is nothing at ``path``. Symlinks, including
    dangling ones, are unlinked rather than followed, so their targets are
    left untouched.

    Raises:
        CompareException: INVALID_PATH_ERROR when deletion fails with PermissionError.
    """
    # os.path.exists() follows symlinks and is False for a dangling link. That
    # made this function a silent no-op for broken links, even though
    # check_file_not_exists() passes them here explicitly. lexists() checks
    # the link itself.
    if not os.path.lexists(path):
        return
    try:
        if os.path.islink(path) or os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)
    except PermissionError as err:
        logger.error("Failed to delete {}. Please check the permission.".format(path))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from err
262
-
263
-
264
def move_file(src_path, dst_path):
    """
    Move ``src_path`` to ``dst_path``, then apply DATA_FILE_AUTHORITY permissions to it.

    ``src_path`` is validated with check_file_or_directory_path (as a
    readable file). ``dst_path`` is checked with check_path_before_create,
    which is not defined in this visible section -- NOTE(review): confirm
    where it comes from.

    Raises:
        RuntimeError: if shutil.move fails for any reason (original error chained).
    """
    check_file_or_directory_path(src_path)
    check_path_before_create(dst_path)
    try:
        shutil.move(src_path, dst_path)
    except Exception as e:
        failure_msg = f"move file {src_path} to {dst_path} failed"
        logger.error(failure_msg)
        raise RuntimeError(failure_msg) from e
    change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY)
273
-
274
-
275
def get_dump_data_path(dump_dir):
    """
    Find the directory under ``dump_dir`` that holds dump data files.

    ``dump_dir`` is first validated as a writable directory. It is then walked
    top-down with os.walk.

    Returns:
        tuple: ``(path, True)`` for the first directory that directly contains
        at least one file. Otherwise ``(last directory visited, False)``, or
        ``(None, False)`` if the walk yields nothing.
    """
    check_file_or_directory_path(dump_dir, True)
    found_path = None
    for dir_path, _, files in os.walk(dump_dir):
        found_path = dir_path
        if files:
            return found_path, True
    return found_path, False
295
-
296
-
297
def create_directory(dir_path):
    """
    Create ``dir_path`` (including parents) with mode 0o700 unless it already exists.

    Before creating, the path is passed to check_path_before_create (defined
    outside this visible section).

    Raises:
        CompareException: INVALID_PATH_ERROR if os.makedirs raises OSError.
    """
    if os.path.exists(dir_path):
        return
    check_path_before_create(dir_path)
    try:
        os.makedirs(dir_path, mode=0o700)
    except OSError as ex:
        logger.error(
            'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex)))
        raise CompareException(CompareException.INVALID_PATH_ERROR) from ex
314
-
315
-
316
def execute_command(cmd):
    """
    Run ``cmd`` without a shell and echo its combined stdout/stderr line by line.

    Args:
        cmd: the command as an argument sequence, e.g. ``["ls", "-l"]``.

    Raises:
        CompareException: INVALID_DATA_ERROR if the command exits with a
            non-zero status.
    """
    # Wrap cmd in a tuple so a tuple-typed command is not treated as
    # %-format arguments.
    logger.info('Execute command:%s' % (cmd,))
    # The context manager closes the stdout pipe and waits for the process on
    # exit, so returncode is set afterwards and no descriptor is leaked.
    with subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as process:
        # Read to EOF rather than polling. The old `while poll() is None` loop
        # stopped as soon as the process exited and dropped any output still
        # buffered in the pipe.
        for raw_line in process.stdout:
            # Decode so the console shows text rather than b'...' reprs.
            line = raw_line.decode(errors="replace").strip()
            if line:
                print(line)
    if process.returncode != 0:
        logger.error('Failed to execute command:%s' % " ".join(cmd))
        raise CompareException(CompareException.INVALID_DATA_ERROR)
335
-
336
-
337
def parse_value_by_comma(value):
    """
    Parse a comma-separated list of integers, e.g. '1,2,4,8' -> [1, 2, 4, 8].

    Each item is whitespace-stripped and must satisfy ``str.isdigit()`` or be
    exactly '-1'.

    Raises:
        CompareException: INVALID_PARAM_ERROR at the first item that is neither.
    """
    parsed_values = []
    for token in (piece.strip() for piece in value.split(Const.COMMA)):
        if not (token.isdigit() or token == '-1'):
            logger.error("please check your input shape.")
            raise CompareException(CompareException.INVALID_PARAM_ERROR)
        parsed_values.append(int(token))
    return parsed_values
351
-
352
-
353
- def get_data_len_by_shape(shape):
354
- data_len = 1
355
- for item in shape:
356
- if item == -1:
357
- logger.error("please check your input shape, one dim in shape is -1.")
358
- return -1
359
- data_len = data_len * item
360
- return data_len
361
-
362
-
363
- def add_time_as_suffix(name):
364
- return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
365
-
366
-
367
- def add_time_with_xlsx(name):
368
- return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
369
-
370
-
371
- def get_time():
372
- return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
373
-
374
-
375
- def format_value(value):
376
- return float('{:.12f}'.format(value))
377
-
378
-
379
- def check_seed_all(seed, mode):
380
- if isinstance(seed, int):
381
- if seed < 0 or seed > Const.MAX_SEED_VALUE:
382
- logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
383
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
384
- else:
385
- logger.error(f"Seed must be integer.")
386
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
387
- if not isinstance(mode, bool):
388
- logger.error(f"seed_all mode must be bool.")
389
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
390
-
391
-
392
- def get_process_rank(model):
393
- logger.info("Rank id is not provided. Trying to get the rank id of the model.")
394
- try:
395
- local_device = next(model.parameters()).device
396
- except StopIteration:
397
- logger.warning('There is no parameter in the model. Fail to get rank id.')
398
- return 0, False
399
- if local_device.type == 'cpu':
400
- logger.warning("Warning: the debugger is unable to get the rank id. "
401
- "This may cause the dumpped data to be corrupted in the "
402
- "case of distributed training. (You may ignore this if you are using only one card.) "
403
- "Transfer the model to npu or gpu before register_hook() to avoid this warning.")
404
- return 0, False
405
- else:
406
- return local_device.index, True
407
-
408
-
409
- def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode):
410
- template_path = os.path.join(os.path.dirname(__file__), "compare_script.template")
411
- pkl_dir = os.path.dirname(pkl_file_path)
412
- compare_script_path = os.path.join(pkl_dir, "compare_data.py")
413
- is_api_stack = "True" if dump_switch_mode == Const.API_STACK else "False"
414
-
415
- try:
416
- with FileOpen(template_path, 'r') as ftemp, \
417
- os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
418
- code_temp = ftemp.read()
419
- fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack))
420
- except OSError:
421
- logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")
422
-
423
- logger.info(f"Generate compare script successfully which is {compare_script_path}.")
424
-
425
-
426
- def check_file_valid(file_path):
427
- if os.path.islink(file_path):
428
- logger.error('The file path {} is a soft link.'.format(file_path))
429
- raise CompareException(CompareException.INVALID_PATH_ERROR)
430
-
431
- if len(os.path.realpath(file_path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(file_path)) > \
432
- Const.FILE_NAME_LENGTH:
433
- logger.error('The file path length exceeds limit.')
434
- raise CompareException(CompareException.INVALID_PATH_ERROR)
435
-
436
- if not re.match(Const.FILE_PATTERN, os.path.realpath(file_path)):
437
- logger.error('The file path {} contains special characters.'.format(file_path))
438
- raise CompareException(CompareException.INVALID_PATH_ERROR)
439
-
440
- if os.path.isfile(file_path):
441
- file_size = os.path.getsize(file_path)
442
- if file_path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
443
- logger.error('The file {} size is greater than 1GB.'.format(file_path))
444
- raise CompareException(CompareException.INVALID_PATH_ERROR)
445
- if file_path.endswith(Const.NUMPY_SUFFIX) and file_size > Const.TEN_GB:
446
- logger.error('The file {} size is greater than 10GB.'.format(file_path))
447
- raise CompareException(CompareException.INVALID_PATH_ERROR)
448
-
449
-
450
- def check_path_before_create(path):
451
- if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
452
- Const.FILE_NAME_LENGTH:
453
- logger.error('The file path length exceeds limit.')
454
- raise CompareException(CompareException.INVALID_PATH_ERROR)
455
-
456
- if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
457
- logger.error('The file path {} contains special characters.'.format(path))
458
- raise CompareException(CompareException.INVALID_PATH_ERROR)
459
-
460
-
461
- def check_inplace_op(prefix):
462
- if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH:
463
- return False
464
- match_op = re.findall(r"Distributed\.(.+?)\.\d", prefix)
465
- op_name = match_op[0] if match_op else None
466
- return op_name in Const.INPLACE_LIST
467
-
468
-
469
- def md5_find(data):
470
- for key_op in data:
471
- for api_info in data[key_op]:
472
- if isinstance(data[key_op][api_info], list):
473
- for data_detail in data[key_op][api_info]:
474
- if data_detail and 'md5' in data_detail:
475
- return True
476
- elif 'md5' in data[key_op][api_info]:
477
- return True
478
- return False
479
-
480
-
481
- def task_dumppath_get(input_param):
482
- npu_path = input_param.get("npu_json_path", None)
483
- bench_path = input_param.get("bench_json_path", None)
484
- if not npu_path or not bench_path:
485
- logger.error(f"Please check the json path is valid.")
486
- raise CompareException(CompareException.INVALID_PATH_ERROR)
487
- with FileOpen(npu_path, 'r') as npu_f:
488
- npu_json_data = json.load(npu_f)
489
- with FileOpen(bench_path, 'r') as bench_f:
490
- bench_json_data = json.load(bench_f)
491
- if npu_json_data['task'] != bench_json_data['task']:
492
- logger.error(f"Please check the dump task is consistent.")
493
- raise CompareException(CompareException.INVALID_TASK_ERROR)
494
- if npu_json_data['task'] == Const.TENSOR:
495
- summary_compare = False
496
- md5_compare = False
497
- elif npu_json_data['task'] == Const.STATISTICS:
498
- md5_compare = md5_find(npu_json_data['data'])
499
- if md5_compare:
500
- summary_compare = False
501
- else:
502
- summary_compare = True
503
- else:
504
- logger.error(f"Compare is not required for overflow_check or free_benchmark.")
505
- raise CompareException(CompareException.INVALID_TASK_ERROR)
506
- input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
507
- input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
508
- return summary_compare, md5_compare
509
-
510
-
511
- def get_header_index(header_name, summary_compare=False):
512
- if summary_compare:
513
- header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
514
- else:
515
- header = CompareConst.COMPARE_RESULT_HEADER[:]
516
- if header_name not in header:
517
- logger.error(f"{header_name} not in data name")
518
- raise CompareException(CompareException.INVALID_PARAM_ERROR)
519
- return header.index(header_name)
520
-
521
-
522
- def convert_tuple(data):
523
- return data if isinstance(data, tuple) else (data, )
524
-
525
-
526
- def write_csv(data, filepath, mode="a+"):
527
- exist = os.path.exists(filepath)
528
- with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
529
- writer = csv.writer(f)
530
- writer.writerows(data)
531
- if not exist:
532
- change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
533
-
534
-
535
- def load_npy(filepath):
536
- check_file_or_directory_path(filepath)
537
- try:
538
- npy = np.load(filepath)
539
- except Exception as e:
540
- logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
541
- raise RuntimeError(f"Load numpy file {filepath} failed.") from e
542
- return npy
543
-
544
-
545
- def save_npy(data, filepath):
546
- filepath = os.path.realpath(filepath)
547
- check_path_before_create(filepath)
548
- try:
549
- np.save(filepath, data)
550
- except Exception as e:
551
- logger.error(f"The numpy file failed to save. Please check the path: {filepath}.")
552
- raise RuntimeError(f"Save numpy file {filepath} failed.") from e
553
- change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
554
-
555
- def save_npy_to_txt(self, data, dst_file='', align=0):
556
- if os.path.exists(dst_file):
557
- self.log.info("Dst file %s exists, will not save new one.", dst_file)
558
- return
559
- shape = data.shape
560
- data = data.flatten()
561
- if align == 0:
562
- align = 1 if len(shape) == 0 else shape[-1]
563
- elif data.size % align != 0:
564
- pad_array = np.zeros((align - data.size % align,))
565
- data = np.append(data, pad_array)
566
- check_path_before_create(dst_file)
567
- try:
568
- np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
569
- except Exception as e:
570
- self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file)
571
- change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
572
-
573
- def get_json_contents(file_path):
574
- ops = get_file_content_bytes(file_path)
575
- try:
576
- json_obj = json.loads(ops)
577
- except ValueError as error:
578
- logger.error('Failed to load json.')
579
- raise CompareException(CompareException.INVALID_FILE_ERROR) from error
580
- if not isinstance(json_obj, dict):
581
- logger.error('Json file content is not a dictionary!')
582
- raise CompareException(CompareException.INVALID_FILE_ERROR)
583
- return json_obj
584
-
585
-
586
- def get_file_content_bytes(file):
587
- with FileOpen(file, 'rb') as file_handle:
588
- return file_handle.read()
589
-
590
-
591
- def load_yaml(yaml_path):
592
- path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
593
- checked_path = path_checker.common_check()
594
- try:
595
- with FileOpen(checked_path, "r") as f:
596
- yaml_data = yaml.safe_load(f)
597
- except Exception as e:
598
- logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.")
599
- raise RuntimeError(f"Load yaml file {checked_path} failed.") from e
600
- return yaml_data
601
-
602
-
603
- def save_workbook(workbook, file_path):
604
- """
605
- 保存工作簿到指定的文件路径
606
- workbook: 要保存的工作簿对象
607
- file_path: 文件保存路径
608
- """
609
- file_path = os.path.realpath(file_path)
610
- check_path_before_create(file_path)
611
- try:
612
- workbook.save(file_path)
613
- except Exception as e:
614
- logger.error(f'Save result file "{os.path.basename(file_path)}" failed')
615
- raise CompareException(CompareException.WRITE_FILE_ERROR) from e
616
- change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ import collections
18
+ import os
19
+ import re
20
+ import subprocess
21
+ import time
22
+ import json
23
+ from datetime import datetime, timezone
24
+
25
+ from msprobe.core.common.file_utils import (FileOpen, check_file_or_directory_path, load_json)
26
+ from msprobe.core.common.const import Const, CompareConst
27
+ from msprobe.core.common.log import logger
28
+ from msprobe.core.common.exceptions import MsprobeException
29
+
30
+
31
+ device = collections.namedtuple('device', ['type', 'index'])
32
+ prefixes = ['api_stack', 'list', 'range', 'acl']
33
+
34
+
35
+ class MsprobeBaseException(Exception):
36
+ """
37
+ Base class for all custom exceptions.
38
+ """
39
+ # 所有的错误代码
40
+ NONE_ERROR = 0
41
+ INVALID_PATH_ERROR = 1
42
+ OPEN_FILE_ERROR = 2
43
+ CLOSE_FILE_ERROR = 3
44
+ READ_FILE_ERROR = 4
45
+ WRITE_FILE_ERROR = 5
46
+ INVALID_FILE_ERROR = 6
47
+ PERMISSION_ERROR = 7
48
+ INDEX_OUT_OF_BOUNDS_ERROR = 8
49
+ NO_DUMP_FILE_ERROR = 9
50
+ INVALID_DATA_ERROR = 10
51
+ INVALID_PARAM_ERROR = 11
52
+ INVALID_DUMP_RATIO = 12
53
+ INVALID_DUMP_FILE = 13
54
+ UNKNOWN_ERROR = 14
55
+ INVALID_DUMP_MODE = 15
56
+ PARSE_FILE_ERROR = 16
57
+ INVALID_COMPARE_MODE = 17
58
+ OVER_SIZE_FILE_ERROR = 18
59
+ INVALID_SUMMARY_MODE = 19
60
+ INVALID_TASK_ERROR = 20
61
+ DETACH_ERROR = 21
62
+ INVALID_OBJECT_TYPE_ERROR = 22
63
+ INVALID_CHAR_ERROR = 23
64
+ RECURSION_LIMIT_ERROR = 24
65
+ INVALID_ATTRIBUTE_ERROR = 25
66
+ OUTPUT_HOOK_ERROR = 26
67
+ INPUT_HOOK_ERROR = 27
68
+ FUNCTION_CALL_ERROR = 28
69
+ FORWARD_DATA_COLLECTION_ERROR = 29
70
+ BACKWARD_DATA_COLLECTION_ERROR = 30
71
+
72
+ def __init__(self, code, error_info: str = ""):
73
+ super(MsprobeBaseException, self).__init__()
74
+ self.code = code
75
+ self.error_info = error_info
76
+
77
+ def __str__(self):
78
+ return self.error_info
79
+
80
+
81
+ class CompareException(MsprobeBaseException):
82
+ """
83
+ Class for Accuracy Compare Exception
84
+ """
85
+
86
+ def __init__(self, code, error_info: str = ""):
87
+ super(CompareException, self).__init__(code, error_info)
88
+
89
+
90
+ class DumpException(MsprobeBaseException):
91
+ """
92
+ Class for Dump Exception
93
+ """
94
+
95
+ def __init__(self, code, error_info: str = ""):
96
+ super(DumpException, self).__init__(code, error_info)
97
+
98
+ def __str__(self):
99
+ return f"Dump Error Code {self.code}: {self.error_info}"
100
+
101
+
102
+ def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False):
103
+ if not isinstance(input_param, dict):
104
+ logger.error(f"Invalid input parameter 'input_param', the expected type dict but got {type(input_param)}.")
105
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
106
+ if not isinstance(output_path, str):
107
+ logger.error(f"Invalid input parameter 'output_path', the expected type str but got {type(output_path)}.")
108
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
109
+
110
+ check_file_or_directory_path(input_param.get("npu_json_path"), False)
111
+ check_file_or_directory_path(input_param.get("bench_json_path"), False)
112
+ check_file_or_directory_path(input_param.get("stack_json_path"), False)
113
+ if not summary_compare and not md5_compare:
114
+ check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True)
115
+ check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True)
116
+ check_file_or_directory_path(output_path, True)
117
+
118
+ with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \
119
+ FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \
120
+ FileOpen(input_param.get("stack_json_path"), "r") as stack_json:
121
+ check_json_file(input_param, npu_json, bench_json, stack_json)
122
+
123
+
124
+ def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, is_print_compare_log=True):
125
+ arg_list = [stack_mode, auto_analyze, fuzzy_match, is_print_compare_log]
126
+ for arg in arg_list:
127
+ if not isinstance(arg, bool):
128
+ logger.error(f"Invalid input parameter, {arg} which should be only bool type.")
129
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
130
+
131
+
132
+ def _check_json(json_file_handle, file_name):
133
+ tensor_line = json_file_handle.readline()
134
+ if not tensor_line:
135
+ logger.error("dump file {} have empty line!".format(file_name))
136
+ raise CompareException(CompareException.INVALID_DUMP_FILE)
137
+ json_file_handle.seek(0, 0)
138
+
139
+
140
+ def check_json_file(input_param, npu_json, bench_json, stack_json):
141
+ _check_json(npu_json, input_param.get("npu_json_path"))
142
+ _check_json(bench_json, input_param.get("bench_json_path"))
143
+ _check_json(stack_json, input_param.get("stack_json_path"))
144
+
145
+
146
+ def check_regex_prefix_format_valid(prefix):
147
+ """
148
+ validate the format of the regex prefix
149
+
150
+ Args:
151
+ prefix (str): The prefix string to validate.
152
+
153
+ Returns:
154
+ no returns
155
+
156
+ Raises:
157
+ ValueError: if the prefix length exceeds Const.REGEX_PREFIX_MAX_LENGTH characters or the prefix do not match
158
+ the given pattern Const.REGEX_PREFIX_PATTERN
159
+ """
160
+ if len(prefix) > Const.REGEX_PREFIX_MAX_LENGTH:
161
+ raise ValueError(f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, while current length "
162
+ f"is {len(prefix)}")
163
+ if not re.match(Const.REGEX_PREFIX_PATTERN, prefix):
164
+ raise ValueError(f"prefix contains invalid characters, prefix pattern {Const.REGEX_PREFIX_PATTERN}")
165
+
166
+
167
+ def execute_command(cmd):
168
+ """
169
+ Function Description:
170
+ run the following command
171
+ Parameter:
172
+ cmd: command
173
+ Exception Description:
174
+ when invalid command throw exception
175
+ """
176
+ logger.info('Execute command:%s' % cmd)
177
+ process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
178
+ while process.poll() is None:
179
+ line = process.stdout.readline()
180
+ line = line.strip()
181
+ if line:
182
+ print(line)
183
+ if process.returncode != 0:
184
+ logger.error('Failed to execute command:%s' % " ".join(cmd))
185
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
186
+
187
+
188
+ def add_time_as_suffix(name):
189
+ return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
190
+
191
+
192
+ def add_time_with_xlsx(name):
193
+ return '{}_{}.xlsx'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
194
+
195
+
196
+ def add_time_with_yaml(name):
197
+ return '{}_{}.yaml'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())))
198
+
199
+
200
+ def get_time():
201
+ return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
202
+
203
+
204
+ def format_value(value):
205
+ return float('{:.12f}'.format(value))
206
+
207
+
208
+ def md5_find(data):
209
+ for key_op in data:
210
+ for api_info in data[key_op]:
211
+ if isinstance(data[key_op][api_info], list):
212
+ for data_detail in data[key_op][api_info]:
213
+ if data_detail and 'md5' in data_detail:
214
+ return True
215
+ elif 'md5' in data[key_op][api_info]:
216
+ return True
217
+ return False
218
+
219
+
220
+ def struct_json_get(input_param, framework):
221
+ if framework == Const.PT_FRAMEWORK:
222
+ prefix = "bench"
223
+ elif framework == Const.MS_FRAMEWORK:
224
+ prefix = "npu"
225
+ else:
226
+ logger.error("Error framework found.")
227
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
228
+
229
+ frame_json_path = input_param.get(f"{prefix}_json_path", None)
230
+ if not frame_json_path:
231
+ logger.error(f"Please check the json path is valid.")
232
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
233
+ directory = os.path.dirname(frame_json_path)
234
+ check_file_or_directory_path(directory, True)
235
+ stack_json = os.path.join(directory, "stack.json")
236
+ construct_json = os.path.join(directory, "construct.json")
237
+
238
+ stack = load_json(stack_json)
239
+ construct = load_json(construct_json)
240
+ return stack, construct
241
+
242
+
243
+ def task_dumppath_get(input_param):
244
+ npu_path = input_param.get("npu_json_path", None)
245
+ bench_path = input_param.get("bench_json_path", None)
246
+ if not npu_path or not bench_path:
247
+ logger.error(f"Please check the json path is valid.")
248
+ raise CompareException(CompareException.INVALID_PATH_ERROR)
249
+ with FileOpen(npu_path, 'r') as npu_f:
250
+ npu_json_data = json.load(npu_f)
251
+ with FileOpen(bench_path, 'r') as bench_f:
252
+ bench_json_data = json.load(bench_f)
253
+ if npu_json_data['task'] != bench_json_data['task']:
254
+ logger.error(f"Please check the dump task is consistent.")
255
+ raise CompareException(CompareException.INVALID_TASK_ERROR)
256
+ if npu_json_data['task'] == Const.TENSOR:
257
+ summary_compare = False
258
+ md5_compare = False
259
+ elif npu_json_data['task'] == Const.STATISTICS:
260
+ md5_compare = md5_find(npu_json_data['data'])
261
+ if md5_compare:
262
+ summary_compare = False
263
+ else:
264
+ summary_compare = True
265
+ else:
266
+ logger.error(f"Compare is not required for overflow_check or free_benchmark.")
267
+ raise CompareException(CompareException.INVALID_TASK_ERROR)
268
+ input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA)
269
+ input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA)
270
+ return summary_compare, md5_compare
271
+
272
+
273
+ def get_header_index(header_name, summary_compare=False):
274
+ if summary_compare:
275
+ header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:]
276
+ else:
277
+ header = CompareConst.COMPARE_RESULT_HEADER[:]
278
+ if header_name not in header:
279
+ logger.error(f"{header_name} not in data name")
280
+ raise CompareException(CompareException.INVALID_PARAM_ERROR)
281
+ return header.index(header_name)
282
+
283
+
284
+ def convert_tuple(data):
285
+ return data if isinstance(data, tuple) else (data, )
286
+
287
+
288
+ def check_op_str_pattern_valid(string, op_name=None, stack=False):
289
+ if isinstance(string, str) and is_invalid_pattern(string):
290
+ if stack:
291
+ message = f"stack info of {op_name} contains special characters, please check!"
292
+ elif not op_name:
293
+ message = f"api name contains special characters, please check!"
294
+ else:
295
+ message = f"data info of {op_name} contains special characters, please check!"
296
+ logger.error(message)
297
+ raise CompareException(CompareException.INVALID_CHAR_ERROR)
298
+
299
+
300
+ def is_invalid_pattern(string):
301
+ pattern = Const.STRING_BLACKLIST
302
+ return re.search(pattern, string)
303
+
304
+
305
+ def print_tools_ends_info():
306
+ total_len = len(Const.TOOL_ENDS_SUCCESSFULLY) + Const.FILL_CHAR_NUMS
307
+ logger.info('*' * total_len)
308
+ logger.info(f"*{Const.TOOL_ENDS_SUCCESSFULLY.center(total_len - 2)}*")
309
+ logger.info('*' * total_len)
310
+
311
+
312
+ def get_step_or_rank_from_string(step_or_rank, obj):
313
+ splited = step_or_rank.split(Const.HYPHEN)
314
+ if len(splited) == 2:
315
+ try:
316
+ borderlines = int(splited[0]), int(splited[1])
317
+ except (ValueError, IndexError) as e:
318
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
319
+ "The hyphen(-) must start and end with decimal numbers.") from e
320
+ else:
321
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
322
+ f'The string parameter for {obj} only supports formats like "3-5". Now string parameter for {obj} is "{step_or_rank}".')
323
+ if all(Const.STEP_RANK_MAXIMUM_RANGE[0] <= b <= Const.STEP_RANK_MAXIMUM_RANGE[1] for b in borderlines):
324
+ if borderlines[0] <= borderlines[1]:
325
+ continual_step_or_rank = list(range(borderlines[0], borderlines[1] + 1))
326
+ else:
327
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
328
+ f'For the hyphen(-) in {obj}, the left boundary ({borderlines[0]}) cannot be greater than the right boundary ({borderlines[1]}).')
329
+ else:
330
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
331
+ f"The boundaries must fall within the range of [{Const.STEP_RANK_MAXIMUM_RANGE[0]}, {Const.STEP_RANK_MAXIMUM_RANGE[1]}].")
332
+ return continual_step_or_rank
333
+
334
+
335
+ def get_real_step_or_rank(step_or_rank_input, obj):
336
+ if obj not in [Const.STEP, Const.RANK]:
337
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
338
+ f"Only support parsing {[Const.STEP, Const.RANK]}, the current parsing object is {obj}.")
339
+ if step_or_rank_input is None:
340
+ return []
341
+ if not isinstance(step_or_rank_input, list):
342
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"{obj} is invalid, it should be a list")
343
+ real_step_or_rank = []
344
+ for element in step_or_rank_input:
345
+ if not isinstance(element, (int, str)):
346
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
347
+ f"{obj} element {element} must be an integer or string.")
348
+ if isinstance(element, int) and element < 0:
349
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
350
+ f"Each element of {obj} must be non-negative, currently it is {element}.")
351
+ if isinstance(element, int) and Const.STEP_RANK_MAXIMUM_RANGE[0] <= element <= Const.STEP_RANK_MAXIMUM_RANGE[1]:
352
+ real_step_or_rank.append(element)
353
+ elif isinstance(element, str) and Const.HYPHEN in element:
354
+ continual_step_or_rank = get_step_or_rank_from_string(element, obj)
355
+ real_step_or_rank.extend(continual_step_or_rank)
356
+ real_step_or_rank = list(set(real_step_or_rank))
357
+ real_step_or_rank.sort()
358
+ return real_step_or_rank
359
+
360
+
361
+ def check_seed_all(seed, mode):
362
+ if isinstance(seed, int):
363
+ if seed < 0 or seed > Const.MAX_SEED_VALUE:
364
+ logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
365
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
366
+ else:
367
+ logger.error("Seed must be integer.")
368
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
369
+ if not isinstance(mode, bool):
370
+ logger.error("seed_all mode must be bool.")
371
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)