mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe精度工具数据dump标准性能基线报告.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,74 +1,213 @@
1
- import os
2
- import re
3
-
4
- from msprobe.core.common.const import FileCheckConst
5
- from msprobe.core.common.file_check import FileChecker
6
- from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
7
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
8
- from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
9
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
10
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
11
-
12
- hf_32_standard_api = ["conv1d", "conv2d"]
13
-
14
-
15
- class Backward_Message:
16
- MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
17
- UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
18
- NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
19
-
20
-
21
- class UtDataInfo:
22
- def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
23
- backward_message, rank=0):
24
- self.bench_grad = bench_grad
25
- self.device_grad = device_grad
26
- self.device_output = device_output
27
- self.bench_output = bench_output
28
- self.grad_in = grad_in
29
- self.in_fwd_data_list = in_fwd_data_list
30
- self.backward_message = backward_message
31
- self.rank = rank
32
-
33
-
34
- def get_validated_result_csv_path(result_csv_path, mode):
35
- if mode not in ['result', 'detail']:
36
- raise ValueError("The csv mode must be result or detail")
37
- result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
38
- file_type=FileCheckConst.CSV_SUFFIX)
39
- validated_result_csv_path = result_csv_path_checker.common_check()
40
- if mode == 'result':
41
- result_csv_name = os.path.basename(validated_result_csv_path)
42
- pattern = r"^accuracy_checking_result_\d{14}\.csv$"
43
- if not re.match(pattern, result_csv_name):
44
- raise ValueError("When continue run ut, please do not modify the result csv name.")
45
- return validated_result_csv_path
46
-
47
-
48
- def get_validated_details_csv_path(validated_result_csv_path):
49
- result_csv_name = os.path.basename(validated_result_csv_path)
50
- details_csv_name = result_csv_name.replace('result', 'details')
51
- details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
52
- details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
53
- ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
54
- validated_details_csv_path = details_csv_path_checker.common_check()
55
- return validated_details_csv_path
56
-
57
-
58
- def exec_api(api_type, api_name, args, kwargs):
59
- if api_type == "Functional":
60
- functional_api = FunctionalOPTemplate(api_name, str, False)
61
- out = functional_api.forward(*args, **kwargs)
62
- if api_type == "Tensor":
63
- tensor_api = TensorOPTemplate(api_name, str, False)
64
- out = tensor_api.forward(*args, **kwargs)
65
- if api_type == "Torch":
66
- torch_api = TorchOPTemplate(api_name, str, False)
67
- out = torch_api.forward(*args, **kwargs)
68
- if api_type == "Aten":
69
- torch_api = AtenOPTemplate(api_name, None, False)
70
- out = torch_api.forward(*args, **kwargs)
71
- if api_type == "NPU":
72
- torch_api = NpuOPTemplate(api_name, None, False)
73
- out = torch_api.forward(*args, **kwargs)
74
- return out
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+ import re
20
+ import torch
21
+
22
+ try:
23
+ import torch_npu
24
+ except ImportError:
25
+ current_device = "cuda"
26
+ else:
27
+ current_device = "npu"
28
+
29
+ from msprobe.core.common.const import FileCheckConst, Const, CompareConst
30
+ from msprobe.core.common.file_utils import FileChecker
31
+ from msprobe.core.common.log import logger
32
+ from msprobe.core.common.utils import CompareException
33
+ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
34
+ from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
35
+ from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
36
+ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
37
+ from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
38
+
39
+ hf_32_standard_api = ["conv1d", "conv2d"]
40
+ not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
41
+ not_raise_dtype_set = {'type_as'}
42
+
43
+ PRECISION_MAPPING = {
44
+ torch.float16: torch.float32,
45
+ torch.bfloat16: torch.float32,
46
+ torch.float32: torch.float64
47
+ }
48
+
49
+
50
+ class BackwardMessage:
51
+ MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
52
+ UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, " \
53
+ "skip backward."
54
+ NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
55
+
56
+
57
+ class UtDataInfo:
58
+ def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
59
+ backward_message, rank=0):
60
+ self.bench_grad = bench_grad
61
+ self.device_grad = device_grad
62
+ self.device_output = device_output
63
+ self.bench_output = bench_output
64
+ self.grad_in = grad_in
65
+ self.in_fwd_data_list = in_fwd_data_list
66
+ self.backward_message = backward_message
67
+ self.rank = rank
68
+
69
+
70
+ def get_validated_result_csv_path(result_csv_path, mode):
71
+ if mode not in ['result', 'detail']:
72
+ raise ValueError("The csv mode must be result or detail")
73
+ result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
74
+ file_type=FileCheckConst.CSV_SUFFIX)
75
+ validated_result_csv_path = result_csv_path_checker.common_check()
76
+ if mode == 'result':
77
+ result_csv_name = os.path.basename(validated_result_csv_path)
78
+ pattern = r"^accuracy_checking_result_\d{14}\.csv$"
79
+ if not re.match(pattern, result_csv_name):
80
+ raise ValueError("When continue run ut, please do not modify the result csv name.")
81
+ return validated_result_csv_path
82
+
83
+
84
+ def get_validated_details_csv_path(validated_result_csv_path):
85
+ result_csv_name = os.path.basename(validated_result_csv_path)
86
+ details_csv_name = result_csv_name.replace('result', 'details')
87
+ details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
88
+ details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
89
+ ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
90
+ validated_details_csv_path = details_csv_path_checker.common_check()
91
+ return validated_details_csv_path
92
+
93
+
94
+ def exec_api(api_type, api_name, device, args, kwargs):
95
+ if api_type == "Functional":
96
+ torch_api = FunctionalOPTemplate(api_name, str, False)
97
+ if api_type == "Tensor":
98
+ torch_api = TensorOPTemplate(api_name, str, False)
99
+ if api_type == "Torch":
100
+ torch_api = TorchOPTemplate(api_name, str, False)
101
+ if api_type == "Aten":
102
+ torch_api = AtenOPTemplate(api_name, None, False)
103
+ if api_type == "NPU":
104
+ torch_api = NpuOPTemplate(api_name, None, False, device)
105
+ out = torch_api.forward(*args, **kwargs)
106
+ return out
107
+
108
+
109
+ def deal_detach(arg, to_detach=True):
110
+ return arg.detach() if to_detach else arg
111
+
112
+
113
+ def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
114
+ '''
115
+ 将标杆数据的dtype转换为raise_dtype
116
+ 输入:
117
+ api_name:api名称
118
+ arg:标杆输入
119
+ raise_dtype:需要转换的dtype
120
+ 输出:
121
+ arg: 转换dtype的标杆输入
122
+ '''
123
+ if api_name in hf_32_standard_api and arg.dtype == torch.float32:
124
+ return arg
125
+ if raise_dtype is None or arg.dtype not in PRECISION_MAPPING or raise_dtype == arg.dtype:
126
+ return arg
127
+ return arg.type(raise_dtype)
128
+
129
+
130
+ def generate_device_params(input_args, input_kwargs, need_backward, api_name):
131
+ def recursive_arg_to_device(arg_in, to_detach, depth=0):
132
+ if depth > Const.MAX_DEPTH:
133
+ logger.error("The depth of arg_in is too large, please check the arg_in.")
134
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
135
+ if isinstance(arg_in, (list, tuple)):
136
+ return type(arg_in)(recursive_arg_to_device(arg, to_detach, depth=depth+1) for arg in arg_in)
137
+ elif isinstance(arg_in, torch.Tensor):
138
+ if need_backward and arg_in.requires_grad:
139
+ arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
140
+ temp_arg_in = arg_in * 1
141
+ arg_in = temp_arg_in.type_as(arg_in)
142
+ arg_in.retain_grad()
143
+ return arg_in
144
+ else:
145
+ return deal_detach(arg_in.clone(), to_detach).to(current_device)
146
+ else:
147
+ return arg_in
148
+
149
+ is_detach = api_name not in not_detach_set
150
+ device_args = recursive_arg_to_device(input_args, is_detach)
151
+ device_kwargs = \
152
+ {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
153
+ return device_args, device_kwargs
154
+
155
+
156
+ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
157
+ def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None, depth=0):
158
+ if depth > Const.MAX_DEPTH:
159
+ logger.error("The depth of arg_in is too large, please check the arg_in.")
160
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
161
+ if isinstance(arg_in, (list, tuple)):
162
+ return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype, depth=depth+1)
163
+ for arg in arg_in)
164
+ elif isinstance(arg_in, torch.Tensor):
165
+ if need_backward and arg_in.requires_grad:
166
+ arg_in = deal_detach(raise_bench_data_dtype(
167
+ api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_()
168
+ temp_arg_in = arg_in * 1
169
+ arg_in = temp_arg_in.type_as(arg_in)
170
+ arg_in.retain_grad()
171
+ return arg_in
172
+ else:
173
+ return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
174
+ else:
175
+ return arg_in
176
+
177
+ def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
178
+ if arg_in.dtype in PRECISION_MAPPING:
179
+ return True
180
+ if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
181
+ return True
182
+ return False
183
+
184
+ def recursive_find_dtypes(arg_in, kwargs=None, check_kwargs=False, depth=0):
185
+ if depth > Const.MAX_DEPTH:
186
+ logger.error("The depth of arg_in is too large, please check the arg_in.")
187
+ raise CompareException(CompareException.RECURSION_LIMIT_ERROR)
188
+ if isinstance(arg_in, (list, tuple)):
189
+ return set().union(*tuple(recursive_find_dtypes(arg, kwargs, check_kwargs=check_kwargs, depth=depth+1) for arg in arg_in))
190
+ elif isinstance(arg_in, torch.Tensor) and is_tensor_with_raise_precision(arg_in, check_kwargs):
191
+ return set([arg_in.dtype])
192
+ elif isinstance(arg_in, dict) and check_kwargs:
193
+ return set().union(*tuple(recursive_find_dtypes(v, kwargs, check_kwargs=True, depth=depth+1) for v in arg_in.values()))
194
+ return set()
195
+
196
+ raise_dtype = None
197
+ need_raise_dtypes = recursive_find_dtypes(input_args)
198
+ need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
199
+ if len(need_raise_dtypes) == 1:
200
+ raise_dtype = PRECISION_MAPPING.get(need_raise_dtypes.pop(), torch.float32)
201
+ elif len(need_raise_dtypes) >= 2:
202
+ raise_dtype = torch.float32
203
+
204
+ raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype
205
+ is_detach = api_name not in not_detach_set
206
+ cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
207
+ cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()}
208
+ return cpu_args, cpu_kwargs
209
+
210
+
211
+ def record_skip_info(api_full_name, compare, compare_alg_results):
212
+ result_info = (api_full_name, CompareConst.SKIP, CompareConst.SKIP, [compare_alg_results], None, 0)
213
+ compare.record_results(result_info)
@@ -1,5 +1,8 @@
1
- {
2
- "topk": {
3
- "grad_index": 0
4
- }
1
+ {
2
+ "topk": {
3
+ "grad_index": 0
4
+ },
5
+ "npu_fusion_attention": {
6
+ "grad_index": 0
7
+ }
5
8
  }