mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
msprobe/pytorch/api_accuracy_checker/common/utils.py
@@ -1,165 +1,246 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """
17
- import os
18
- import re
19
-
20
- import torch
21
-
22
- try:
23
- import torch_npu
24
- except ImportError:
25
- IS_GPU = True
26
- else:
27
- IS_GPU = False
28
-
29
- from msprobe.pytorch.common.log import logger
30
- from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
31
- from msprobe.core.common.const import Const, FileCheckConst
32
- from msprobe.core.common.utils import CompareException
33
-
34
-
35
- class DumpException(CompareException):
36
- pass
37
-
38
-
39
- def check_object_type(check_object, allow_type):
40
- """
41
- Function Description:
42
- Check if the object belongs to a certain data type
43
- Parameter:
44
- check_object: the object to be checked
45
- allow_type: legal data type
46
- Exception Description:
47
- when invalid data throw exception
48
- """
49
- if not isinstance(check_object, allow_type):
50
- logger.error(f"{check_object} not of {allow_type} type")
51
- raise CompareException(CompareException.INVALID_DATA_ERROR)
52
-
53
-
54
- class SoftlinkCheckException(Exception):
55
- pass
56
-
57
-
58
- def check_need_convert(api_name):
59
- convert_type = None
60
- for key, value in Const.CONVERT_API.items():
61
- if api_name not in value:
62
- continue
63
- else:
64
- convert_type = key
65
- return convert_type
66
-
67
-
68
- def api_info_preprocess(api_name, api_info_dict):
69
- """
70
- Function Description:
71
- Preprocesses the API information.
72
- Parameter:
73
- api_name: Name of the API.
74
- api_info_dict: argument of the API.
75
- Return api_info_dict:
76
- convert_type: Type of conversion.
77
- api_info_dict: Processed argument of the API.
78
- """
79
- convert_type = check_need_convert(api_name)
80
- if api_name == 'cross_entropy':
81
- api_info_dict = cross_entropy_process(api_info_dict)
82
- return convert_type, api_info_dict
83
-
84
-
85
- def cross_entropy_process(api_info_dict):
86
- """
87
- Function Description:
88
- Preprocesses the cross_entropy API information.
89
- Parameter:
90
- api_info_dict: argument of the API.
91
- Return api_info_dict:
92
- api_info_dict: Processed argument of the API.
93
- """
94
- if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]:
95
- if api_info_dict['args'][1]['Min'] <= 0:
96
- # The second argument in cross_entropy should be -100 or not less than 0
97
- api_info_dict['args'][1]['Min'] = 0
98
- return api_info_dict
99
-
100
-
101
- def initialize_save_path(save_path, dir_name):
102
- data_path = os.path.join(save_path, dir_name)
103
- if os.path.exists(data_path):
104
- logger.warning(f"{data_path} already exists, it will be overwritten")
105
- else:
106
- os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY)
107
- data_path_checker = FileChecker(data_path, FileCheckConst.DIR)
108
- data_path_checker.common_check()
109
- return data_path
110
-
111
-
112
- def write_pt(file_path, tensor):
113
- if os.path.exists(file_path):
114
- raise ValueError(f"File {file_path} already exists")
115
- torch.save(tensor, file_path)
116
- full_path = os.path.realpath(file_path)
117
- change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY)
118
- return full_path
119
-
120
-
121
- def get_real_data_path(file_path):
122
- targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+']
123
- pattern = re.compile(r'({})'.format('|'.join(targets)))
124
- match = pattern.search(file_path)
125
- if match:
126
- target_index = match.start()
127
- target_path = file_path[target_index:]
128
- return target_path
129
- else:
130
- raise DumpException(DumpException.INVALID_PATH_ERROR)
131
-
132
-
133
- def get_full_data_path(data_path, real_data_path):
134
- if not data_path:
135
- return data_path
136
- full_data_path = os.path.join(real_data_path, data_path)
137
- return os.path.realpath(full_data_path)
138
-
139
-
140
- class UtDataProcessor:
141
- def __init__(self, save_path):
142
- self.save_path = save_path
143
- self.index = 0
144
-
145
- def save_tensors_in_element(self, api_name, element):
146
- self.index = 0
147
- self._save_recursive(api_name, element)
148
-
149
- def _save_recursive(self, api_name, element):
150
- if isinstance(element, torch.Tensor):
151
- api_args = api_name + Const.SEP + str(self.index)
152
- create_directory(self.save_path)
153
- file_path = os.path.join(self.save_path, f'{api_args}.pt')
154
- write_pt(file_path, element.contiguous().cpu().detach())
155
- self.index += 1
156
- elif element is None or isinstance(element, (bool, int, float, str, slice)):
157
- self.index += 1
158
- elif isinstance(element, (list, tuple)):
159
- for item in element:
160
- self._save_recursive(api_name, item)
161
- elif isinstance(element, dict):
162
- for value in element.values():
163
- self._save_recursive(api_name, value)
164
- else:
165
- self.index += 1
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+ import re
20
+ from collections import namedtuple
21
+ import importlib
22
+
23
+ import torch
24
+
25
+ try:
26
+ import torch_npu
27
+ except ImportError:
28
+ IS_GPU = True
29
+ else:
30
+ IS_GPU = False
31
+
32
+ from msprobe.pytorch.common.log import logger
33
+ from msprobe.pytorch.common.utils import save_pt
34
+ from msprobe.core.common.file_utils import create_directory
35
+ from msprobe.core.common.const import Const
36
+ from msprobe.core.common.utils import CompareException
37
+
38
+ ApiData = namedtuple('ApiData', ['name', 'args', 'kwargs', 'result', 'step', 'rank'],
39
+ defaults=['unknown', None, None, None, 0, 0])
40
+
41
+
42
+ class DumpException(CompareException):
43
+ pass
44
+
45
+
46
+ def check_object_type(check_object, allow_type):
47
+ """
48
+ Function Description:
49
+ Check if the object belongs to a certain data type
50
+ Parameter:
51
+ check_object: the object to be checked
52
+ allow_type: legal data type
53
+ Exception Description:
54
+ when invalid data throw exception
55
+ """
56
+ if not isinstance(check_object, allow_type):
57
+ logger.error(f"{check_object} not of {allow_type} type")
58
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
59
+
60
+
61
+ class SoftlinkCheckException(Exception):
62
+ pass
63
+
64
+
65
+ def check_need_convert(api_name):
66
+ convert_type = None
67
+ for key, value in Const.CONVERT_API.items():
68
+ if api_name not in value:
69
+ continue
70
+ else:
71
+ convert_type = key
72
+ return convert_type
73
+
74
+
75
+ def api_info_preprocess(api_name, api_info_dict):
76
+ """
77
+ Function Description:
78
+ Preprocesses the API information.
79
+ Parameter:
80
+ api_name: Name of the API.
81
+ api_info_dict: argument of the API.
82
+ Return api_info_dict:
83
+ convert_type: Type of conversion.
84
+ api_info_dict: Processed argument of the API.
85
+ """
86
+ convert_type = check_need_convert(api_name)
87
+ if api_name == 'cross_entropy':
88
+ api_info_dict = cross_entropy_process(api_info_dict)
89
+ return convert_type, api_info_dict
90
+
91
+
92
+ def cross_entropy_process(api_info_dict):
93
+ """
94
+ Function Description:
95
+ Preprocesses the cross_entropy API information.
96
+ Parameter:
97
+ api_info_dict: argument of the API.
98
+ Return api_info_dict:
99
+ api_info_dict: Processed argument of the API.
100
+ """
101
+ if 'input_args' in api_info_dict and len(api_info_dict['input_args']) > 1 \
102
+ and 'Min' in api_info_dict['input_args'][1]:
103
+ if api_info_dict['input_args'][1]['Min'] <= 0:
104
+ # The second argument in cross_entropy should be -100 or not less than 0
105
+ api_info_dict['input_args'][1]['Min'] = 0
106
+ return api_info_dict
107
+
108
+
109
+ def initialize_save_path(save_path, dir_name):
110
+ data_path = os.path.join(save_path, dir_name)
111
+ create_directory(data_path)
112
+ return data_path
113
+
114
+
115
+ def get_full_data_path(data_path, real_data_path):
116
+ if not data_path:
117
+ return data_path
118
+ full_data_path = os.path.join(real_data_path, data_path)
119
+ return os.path.realpath(full_data_path)
120
+
121
+
122
+ class UtDataProcessor:
123
+ def __init__(self, save_path):
124
+ self.save_path = save_path
125
+ self.index = 0
126
+
127
+ def save_tensors_in_element(self, api_name, element):
128
+ self.index = 0
129
+ self._save_recursive(api_name, element)
130
+
131
+ def _save_recursive(self, api_name, element, depth=0):
132
+ if depth > Const.MAX_DEPTH:
133
+ logger.error(f"Maximum depth of {Const.MAX_DEPTH} exceeded for {api_name}")
134
+ raise DumpException(DumpException.RECURSION_LIMIT_ERROR)
135
+ if isinstance(element, torch.Tensor):
136
+ api_args = api_name + Const.SEP + str(self.index)
137
+ create_directory(self.save_path)
138
+ file_path = os.path.join(self.save_path, f'{api_args}.pt')
139
+ try:
140
+ tensor = element.contiguous().detach().cpu()
141
+ except Exception as err:
142
+ logger.error(f"Failed to transfer tensor to cpu for {api_args}")
143
+ raise DumpException(DumpException.INVALID_DATA_ERROR) from err
144
+ save_pt(tensor, file_path)
145
+ self.index += 1
146
+ elif element is None or isinstance(element, (bool, int, float, str, slice)):
147
+ self.index += 1
148
+ elif isinstance(element, (list, tuple)):
149
+ for item in element:
150
+ self._save_recursive(api_name, item, depth=depth+1)
151
+ elif isinstance(element, dict):
152
+ for value in element.values():
153
+ self._save_recursive(api_name, value, depth=depth+1)
154
+ else:
155
+ self.index += 1
156
+
157
+
158
+ def extract_basic_api_segments(api_full_name):
159
+ """
160
+ Function Description:
161
+ Extract the name of the API.
162
+ Parameter:
163
+ api_full_name: Full name of the API. Example: torch.matmul.0, torch.linalg.inv.0
164
+ Return:
165
+ api_type: Type of api. Example: torch, tensor, etc.
166
+ api_name: Name of api. Example: matmul, linalg.inv, etc.
167
+ """
168
+ api_type = None
169
+ api_parts = api_full_name.split(Const.SEP)
170
+ api_parts_length = len(api_parts)
171
+ if api_parts_length == Const.THREE_SEGMENT:
172
+ api_type, api_name, _ = api_parts
173
+ elif api_parts_length == Const.FOUR_SEGMENT:
174
+ api_type, prefix, api_name, _ = api_parts
175
+ api_name = Const.SEP.join([prefix, api_name])
176
+ else:
177
+ api_name = None
178
+ return api_type, api_name
179
+
180
+
181
+ def extract_detailed_api_segments(full_api_name_with_direction_status):
182
+ """
183
+ Function Description:
184
+ Extract the name of the API.
185
+ Parameter:
186
+ full_api_name_with_direction_status: Full name of the API. Example: torch.matmul.0.forward.output.0
187
+ Return:
188
+ api_name: Name of api. Example: matmul, mul, etc.
189
+ full_api_name: Full name of api. Example: torch.matmul.0
190
+ direction_status: Direction status of api. Example: forward, backward, etc.
191
+ """
192
+ api_type = None
193
+ prefix = None
194
+ api_name = None
195
+ direction_status = None
196
+ api_parts = full_api_name_with_direction_status.split(Const.SEP)
197
+ api_parts_length = len(api_parts)
198
+ if api_parts_length == Const.SIX_SEGMENT:
199
+ api_type, api_name, api_order, direction_status, _, _ = api_parts
200
+ full_api_name = Const.SEP.join([api_type, api_name, api_order])
201
+ elif api_parts_length == Const.SEVEN_SEGMENT:
202
+ api_type, prefix, api_name, api_order, direction_status, _, _ = api_parts
203
+ full_api_name = Const.SEP.join([api_type, prefix, api_name, api_order])
204
+ api_name = Const.SEP.join([prefix, api_name])
205
+ else:
206
+ full_api_name = None
207
+ return api_name, full_api_name, direction_status
208
+
209
+
210
+ def get_module_and_atttribute_name(attribute):
211
+ '''
212
+ Function Description:
213
+ Get the module and attribute name.
214
+ Parameter:
215
+ name: Attribute of a module. Example: torch.float16
216
+ Return:
217
+ module_name: Name of the module. Example: torch.
218
+ attribute_name: Name of the attribute. Example: float16.
219
+ '''
220
+ try:
221
+ module_name, attribute_name = attribute.split(Const.SEP)
222
+ except ValueError as e:
223
+ logger.error(f"Failed to get module and attribute name from {attribute}")
224
+ raise CompareException(CompareException.INVALID_DATA_ERROR) from e
225
+ return module_name, attribute_name
226
+
227
+
228
+ def get_attribute(module_name, attribute_name):
229
+ '''
230
+ Function Description:
231
+ Get the attribute of the module.
232
+ Parameter:
233
+ module_name: Name of the module.
234
+ attribute_name: Name of the attribute.
235
+ '''
236
+ attribute = None
237
+ if module_name not in Const.MODULE_WHITE_LIST:
238
+ logger.error(f"Module {module_name} is not in white list")
239
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
240
+ try:
241
+ module = importlib.import_module(module_name)
242
+ attribute = getattr(module, attribute_name)
243
+ except (ImportError, AttributeError) as e:
244
+ logger.error(f"Failed to get attribute {attribute_name} from module {module_name}: {e}")
245
+ raise CompareException(CompareException.INVALID_ATTRIBUTE_ERROR) from e
246
+ return attribute