mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +0 -90
  232. msprobe/pytorch/doc//321/206/320/247/320/260/321/206/320/260/320/227/321/206/320/255/320/226/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/205/320/254/342/225/221/321/206/320/251/320/277/321/211/320/272/320/234/321/210/320/277/320/221/321/205/320/242/320/234/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,74 +1,70 @@
- import os
- import re
-
- from msprobe.core.common.const import FileCheckConst
- from msprobe.core.common.file_check import FileChecker
- from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
- from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
- from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
- from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
- from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
-
- hf_32_standard_api = ["conv1d", "conv2d"]
-
-
- class Backward_Message:
-     MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
-     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
-     NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
-
-
- class UtDataInfo:
-     def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
-                  backward_message, rank=0):
-         self.bench_grad = bench_grad
-         self.device_grad = device_grad
-         self.device_output = device_output
-         self.bench_output = bench_output
-         self.grad_in = grad_in
-         self.in_fwd_data_list = in_fwd_data_list
-         self.backward_message = backward_message
-         self.rank = rank
-
-
- def get_validated_result_csv_path(result_csv_path, mode):
-     if mode not in ['result', 'detail']:
-         raise ValueError("The csv mode must be result or detail")
-     result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
-                                           file_type=FileCheckConst.CSV_SUFFIX)
-     validated_result_csv_path = result_csv_path_checker.common_check()
-     if mode == 'result':
-         result_csv_name = os.path.basename(validated_result_csv_path)
-         pattern = r"^accuracy_checking_result_\d{14}\.csv$"
-         if not re.match(pattern, result_csv_name):
-             raise ValueError("When continue run ut, please do not modify the result csv name.")
-     return validated_result_csv_path
-
-
- def get_validated_details_csv_path(validated_result_csv_path):
-     result_csv_name = os.path.basename(validated_result_csv_path)
-     details_csv_name = result_csv_name.replace('result', 'details')
-     details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
-     details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
-                                            ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
-     validated_details_csv_path = details_csv_path_checker.common_check()
-     return validated_details_csv_path
-
-
- def exec_api(api_type, api_name, args, kwargs):
-     if api_type == "Functional":
-         functional_api = FunctionalOPTemplate(api_name, str, False)
-         out = functional_api.forward(*args, **kwargs)
-     if api_type == "Tensor":
-         tensor_api = TensorOPTemplate(api_name, str, False)
-         out = tensor_api.forward(*args, **kwargs)
-     if api_type == "Torch":
-         torch_api = TorchOPTemplate(api_name, str, False)
-         out = torch_api.forward(*args, **kwargs)
-     if api_type == "Aten":
-         torch_api = AtenOPTemplate(api_name, None, False)
-         out = torch_api.forward(*args, **kwargs)
-     if api_type == "NPU":
-         torch_api = NpuOPTemplate(api_name, None, False)
-         out = torch_api.forward(*args, **kwargs)
-     return out
+ import os
+ import re
+
+ from msprobe.core.common.const import FileCheckConst
+ from msprobe.core.common.file_utils import FileChecker
+ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
+ from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
+ from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
+ from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
+ from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
+
+ hf_32_standard_api = ["conv1d", "conv2d"]
+
+
+ class Backward_Message:
+     MULTIPLE_BACKWARD_MESSAGE = "Multiple backward is not supported."
+     UNSUPPORT_BACKWARD_MESSAGE = "function with out=... arguments don't support automatic differentiation, skip backward."
+     NO_BACKWARD_RESULT_MESSAGE = "function backward result is None, skip backward."
+
+
+ class UtDataInfo:
+     def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list,
+                  backward_message, rank=0):
+         self.bench_grad = bench_grad
+         self.device_grad = device_grad
+         self.device_output = device_output
+         self.bench_output = bench_output
+         self.grad_in = grad_in
+         self.in_fwd_data_list = in_fwd_data_list
+         self.backward_message = backward_message
+         self.rank = rank
+
+
+ def get_validated_result_csv_path(result_csv_path, mode):
+     if mode not in ['result', 'detail']:
+         raise ValueError("The csv mode must be result or detail")
+     result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE,
+                                           file_type=FileCheckConst.CSV_SUFFIX)
+     validated_result_csv_path = result_csv_path_checker.common_check()
+     if mode == 'result':
+         result_csv_name = os.path.basename(validated_result_csv_path)
+         pattern = r"^accuracy_checking_result_\d{14}\.csv$"
+         if not re.match(pattern, result_csv_name):
+             raise ValueError("When continue run ut, please do not modify the result csv name.")
+     return validated_result_csv_path
+
+
+ def get_validated_details_csv_path(validated_result_csv_path):
+     result_csv_name = os.path.basename(validated_result_csv_path)
+     details_csv_name = result_csv_name.replace('result', 'details')
+     details_csv_path = os.path.join(os.path.dirname(validated_result_csv_path), details_csv_name)
+     details_csv_path_checker = FileChecker(details_csv_path, FileCheckConst.FILE,
+                                            ability=FileCheckConst.READ_WRITE_ABLE, file_type=FileCheckConst.CSV_SUFFIX)
+     validated_details_csv_path = details_csv_path_checker.common_check()
+     return validated_details_csv_path
+
+
+ def exec_api(api_type, api_name, device, args, kwargs):
+     if api_type == "Functional":
+         torch_api = FunctionalOPTemplate(api_name, str, False)
+     if api_type == "Tensor":
+         torch_api = TensorOPTemplate(api_name, str, False)
+     if api_type == "Torch":
+         torch_api = TorchOPTemplate(api_name, str, False)
+     if api_type == "Aten":
+         torch_api = AtenOPTemplate(api_name, None, False)
+     if api_type == "NPU":
+         torch_api = NpuOPTemplate(api_name, None, False, device)
+     out = torch_api.forward(*args, **kwargs)
+     return out
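Note on the hunk above (its -74/+70 counts match msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py in the file list): exec_api gains a device parameter that is forwarded only to NpuOPTemplate, and the per-branch forward calls are collapsed into a single torch_api.forward at the end. Below is an illustrative re-expression of that dispatch as a lookup table; it is a sketch, not the code shipped in the wheel, and the explicit ValueError for an unknown api_type is an addition here (the merged code would otherwise hit an unbound torch_api).

```python
# Illustrative sketch only: the exec_api dispatch from the hunk above written
# as a lookup table. Constructor calls mirror the branches in the diff; only
# the "NPU" template receives the new device argument.
from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate


def exec_api_sketch(api_type, api_name, device, args, kwargs):
    factories = {
        "Functional": lambda: FunctionalOPTemplate(api_name, str, False),
        "Tensor": lambda: TensorOPTemplate(api_name, str, False),
        "Torch": lambda: TorchOPTemplate(api_name, str, False),
        "Aten": lambda: AtenOPTemplate(api_name, None, False),
        "NPU": lambda: NpuOPTemplate(api_name, None, False, device),
    }
    if api_type not in factories:
        # Added for the sketch: fail loudly instead of referencing an unbound name.
        raise ValueError(f"unsupported api_type: {api_type}")
    return factories[api_type]().forward(*args, **kwargs)
```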
@@ -1,5 +1,8 @@
- {
-     "topk": {
-         "grad_index": 0
-     }
+ {
+     "topk": {
+         "grad_index": 0
+     },
+     "npu_fusion_attention": {
+         "grad_index": 0
+     }
  }
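The -4/+7 hunk above matches msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json: alongside topk, npu_fusion_attention now also carries "grad_index": 0, presumably indicating which element of that multi-output API's result run_ut should differentiate. A hypothetical reader for such a setting file is sketched below; the function name and defaults are illustrative, not part of msprobe's API.

```python
# Hypothetical helper (not part of msprobe): look up the grad_index recorded
# for an API name in torch_ut_setting.json, e.g. "npu_fusion_attention" -> 0.
import json


def get_grad_index(setting_path, api_name, default=None):
    with open(setting_path, "r", encoding="utf-8") as f:
        settings = json.load(f)
    return settings.get(api_name, {}).get("grad_index", default)


# get_grad_index("torch_ut_setting.json", "npu_fusion_attention")  # -> 0
```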
@@ -1,202 +1,197 @@
- import io
- import os.path
- import time
- import re
- from pathlib import Path
- from multiprocessing import Queue
- from typing import Optional, Union, Dict, Any
- from collections import namedtuple
- from dataclasses import dataclass
-
- import torch
-
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
- from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
- from msprobe.pytorch.common.utils import logger
- from msprobe.pytorch.common.utils import save_pt
- from msprobe.core.common.utils import remove_path
-
-
- ApiData = namedtuple('ApiData', ['name', 'args', 'kwargs', 'result', 'step', 'rank'],
-                      defaults=['unknown', None, None, None, 0, 0])
- BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
-
-
- @dataclass
- class ATTLConfig:
-     is_benchmark_device: bool
-     connect_ip: str
-     connect_port: int
-     # storage_config
-     nfs_path: str = None
-     tls_path: str = None
-     check_sum: bool = True
-     queue_size: int = 50
-
-
- class ATTL:
-     def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
-         self.session_id = session_id
-         self.session_config = session_config
-         self.logger = logger
-         self.socket_manager = None
-         self.data_queue = Queue(maxsize=50)
-         self.dequeue_list = []
-         self.message_end = False
-         self.kill_progress = False
-         self.check_attl_config()
-         if self.session_config.nfs_path:
-             self.nfs_path = Path(self.session_config.nfs_path)
-         elif self.session_config.is_benchmark_device:
-
-             self.socket_manager = TCPServer(self.session_config.connect_port,
-                                             self.data_queue,
-                                             self.session_config.check_sum,
-                                             self.session_config.tls_path)
-             self.socket_manager.start()
-         elif need_dump:
-             self.socket_manager = TCPClient(self.session_config.connect_ip,
-                                             self.session_config.connect_port,
-                                             self.session_config.check_sum,
-                                             self.session_config.tls_path)
-             self.socket_manager.start()
-
-     def check_attl_config(self):
-         if self.session_config.nfs_path:
-             if os.path.exists(self.session_config.nfs_path):
-                 return
-             else:
-                 raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
-         ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
-         if not re.match(ipv4_pattern, self.session_config.connect_ip):
-             raise Exception(f"host {self.session_config.connect_ip} is invalid.")
-         if not (0 < self.session_config.connect_port <= 65535):
-             raise Exception(f"port {self.session_config.connect_port} is invalid.")
-
-     def stop_serve(self):
-         if isinstance(self.socket_manager, TCPServer):
-             self.socket_manager.stop()
-
-     def send(self, buffer: BufferType) -> None:
-         """
-         npu major in 'send' (client)
-         """
-         # know receiver receive and go next
-         if isinstance(buffer, ApiData):
-             buffer = move2target_device(buffer, torch.device('cpu'))
-
-         if 'device' in buffer.kwargs:
-             buffer.kwargs.pop('device')
-         rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
-         step = buffer.step if hasattr(buffer, "step") else 0
-         io_buff = io.BytesIO()
-         try:
-             torch.save(buffer, io_buff)
-         except Exception as e:
-             self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
-             return
-         data = io_buff.getvalue()
-         self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
-
-     def recv(self, timeout_ms=0) -> Optional[BufferType]:
-         buffer = None
-         while buffer is None:
-             if timeout_ms > 0:
-                 time.sleep(timeout_ms / 1000.0)
-             if buffer is None and not self.data_queue.empty():
-                 buffer = self.data_queue.get()
-                 break
-             if buffer is None and timeout_ms > 0:  # timeout is the only case we give up and return None
-                 break
-             if self.message_end and self.data_queue.empty():
-                 buffer = b"KILL_CONFIRM"
-                 self.kill_progress = True
-                 break
-             time.sleep(0.1)  # waiting outside the lock before next attempt
-         if buffer is None:
-             # this is a result of a timeout
-             self.logger.info(f"RECEIVE API DATA TIMED OUT")
-         else:
-             if buffer == b"STOP_":
-                 return "STOP_"
-             if buffer == b"KILL_":
-                 self.message_end = True
-                 return "STOP_"
-             if buffer == b"KILL_CONFIRM":
-                 self.kill_progress = True
-                 return "KILL_"
-             buffer = io.BytesIO(buffer)
-             try:
-                 buffer = torch.load(buffer, map_location="cpu")
-             except Exception as e:
-                 self.logger.warning("there is something error. please check it. %s", e)
-             if isinstance(buffer, bytes):
-                 return None
-             if isinstance(buffer, str):
-                 return buffer
-
-         return buffer
-
-     def upload(self, buffer: BufferType):
-         if isinstance(buffer, ApiData):
-             buffer = move2target_device(buffer, torch.device('cpu'))
-             file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
-         else:
-             file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
-
-         try:
-             save_pt(buffer, file_path)
-         except Exception as e:
-             self.logger.warning("there is something error in save_pt. please check it. %s", e)
-
-     def download(self):
-         for file_type in ("start*", "*.pt", "end*"):
-             cur_file = next(self.nfs_path.glob(file_type), None)
-             if cur_file is not None:
-                 break
-
-         if cur_file is None:
-             return None
-         else:
-             buffer = None
-             try:
-                 buffer = torch.load(cur_file)
-             except Exception as e:
-                 self.logger.warning("there is something error. please check it. %s", e)
-             remove_path(cur_file)
-             return buffer
-
-
- def move2device_exec(obj, device):
-     if isinstance(obj, (tuple, list)):
-         data_list = [move2device_exec(val, device) for val in obj]
-         return data_list if isinstance(obj, list) else tuple(data_list)
-     if isinstance(obj, dict):
-         return {key: move2device_exec(val, device) for key, val in obj.items()}
-     elif isinstance(obj, torch.Tensor):
-         obj = obj.detach()
-         if obj.device.type != device:
-             obj = obj.to(device)
-         return obj
-     elif "return_types" in str(type(obj)):
-         return move2device_exec(tuple(obj), device)
-     elif isinstance(obj, torch._C.device):
-         return torch.device(device)
-     else:
-         return obj
-
-
- def move2target_device(buffer: ApiData, target_device):
-     # handle args
-     new_args = move2device_exec(buffer.args, target_device)
-
-     # handle kwargs
-     new_kwargs = move2device_exec(buffer.kwargs, target_device)
-
-     # handle result
-     new_results = move2device_exec(buffer.result, target_device)
-
-     if target_device == torch.device('cpu') or target_device == "cpu":
-         return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
-     else:
-         return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)
+ import glob
+ import os.path
+ import time
+ import re
+ from multiprocessing import Queue
+ from typing import Optional, Union, Dict, Any
+ from dataclasses import dataclass
+
+ import torch
+
+ from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient
+ from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer
+ from msprobe.pytorch.common.utils import logger
+ from msprobe.core.common.file_utils import remove_path
+ from msprobe.pytorch.common.utils import save_api_data, load_api_data, save_pt, load_pt
+
+ BufferType = Union[ApiData, Dict[str, Any], str]  # Union[Tensor, Tuple[Optional[Tensor]]]
+
+
+ @dataclass
+ class ATTLConfig:
+     is_benchmark_device: bool
+     connect_ip: str
+     connect_port: int
+     # storage_config
+     nfs_path: str = None
+     tls_path: str = None
+     check_sum: bool = True
+     queue_size: int = 50
+
+
+ class ATTL:
+     def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None:
+         self.session_id = session_id
+         self.session_config = session_config
+         self.logger = logger
+         self.socket_manager = None
+         self.data_queue = Queue(maxsize=50)
+         self.dequeue_list = []
+         self.message_end = False
+         self.kill_progress = False
+         self.check_attl_config()
+         if self.session_config.nfs_path:
+             self.nfs_path = self.session_config.nfs_path
+         elif self.session_config.is_benchmark_device:
+
+             self.socket_manager = TCPServer(self.session_config.connect_port,
+                                             self.data_queue,
+                                             self.session_config.check_sum,
+                                             self.session_config.tls_path)
+             self.socket_manager.start()
+         elif need_dump:
+             self.socket_manager = TCPClient(self.session_config.connect_ip,
+                                             self.session_config.connect_port,
+                                             self.session_config.check_sum,
+                                             self.session_config.tls_path)
+             self.socket_manager.start()
+
+     def check_attl_config(self):
+         if self.session_config.nfs_path:
+             if os.path.exists(self.session_config.nfs_path):
+                 return
+             else:
+                 raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.")
+         ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
+         if not re.match(ipv4_pattern, self.session_config.connect_ip):
+             raise Exception(f"host {self.session_config.connect_ip} is invalid.")
+         if not (0 < self.session_config.connect_port <= 65535):
+             raise Exception(f"port {self.session_config.connect_port} is invalid.")
+
+     def stop_serve(self):
+         if isinstance(self.socket_manager, TCPServer):
+             self.socket_manager.stop()
+
+     def send(self, buffer: BufferType) -> None:
+         """
+         npu major in 'send' (client)
+         """
+         # know receiver receive and go next
+         if isinstance(buffer, ApiData):
+             buffer = move2target_device(buffer, torch.device('cpu'))
+
+         if 'device' in buffer.kwargs:
+             buffer.kwargs.pop('device')
+         rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0
+         step = buffer.step if hasattr(buffer, "step") else 0
+         try:
+             io_buff = save_api_data(buffer)
+         except Exception as e:
+             self.logger.info(f"{buffer.name} can not be saved, skip: {e}")
+             return
+         data = io_buff.getvalue()
+         self.socket_manager.add_to_sending_queue(data, rank=rank, step=step)
+
+     def recv(self, timeout_ms=0) -> Optional[BufferType]:
+         buffer = None
+         while buffer is None:
+             if timeout_ms > 0:
+                 time.sleep(timeout_ms / 1000.0)
+             if buffer is None and not self.data_queue.empty():
+                 buffer = self.data_queue.get()
+                 break
+             if buffer is None and timeout_ms > 0:  # timeout is the only case we give up and return None
+                 break
+             if self.message_end and self.data_queue.empty():
+                 buffer = b"KILL_CONFIRM"
+                 self.kill_progress = True
+                 break
+             time.sleep(0.1)  # waiting outside the lock before next attempt
+         if buffer is None:
+             # this is a result of a timeout
+             self.logger.info(f"RECEIVE API DATA TIMED OUT")
+         else:
+             if buffer == b"STOP_":
+                 return "STOP_"
+             if buffer == b"KILL_":
+                 self.message_end = True
+                 return "STOP_"
+             if buffer == b"KILL_CONFIRM":
+                 self.kill_progress = True
+                 return "KILL_"
+             try:
+                 buffer = load_api_data(buffer)
+             except Exception as e:
+                 self.logger.warning("there is something error. please check it. %s", e)
+             if isinstance(buffer, bytes):
+                 return None
+             if isinstance(buffer, str):
+                 return buffer
+
+         return buffer
+
+     def upload(self, buffer: BufferType):
+         if isinstance(buffer, ApiData):
+             buffer = move2target_device(buffer, torch.device('cpu'))
+             file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
+         else:
+             file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
+
+         try:
+             save_pt(buffer, file_path)
+         except Exception as e:
+             self.logger.warning("there is something error in save_pt. please check it. %s", e)
+
+     def download(self):
+         buffer = None
+         cur_file = None
+         for file_type in ("start*", "*.pt", "end*"):
+             pattern = os.path.join(self.nfs_path, file_type)
+             files = glob.glob(pattern)
+             if len(files) > 0:
+                 cur_file = files[0]
+                 break
+
+         if cur_file is not None:
+             try:
+                 buffer = load_pt(cur_file)
+             except Exception as e:
+                 self.logger.warning("there is something error. please check it. %s", e)
+             remove_path(cur_file)
+         return buffer
+
+
+ def move2device_exec(obj, device):
+     if isinstance(obj, (tuple, list)):
+         data_list = [move2device_exec(val, device) for val in obj]
+         return data_list if isinstance(obj, list) else tuple(data_list)
+     if isinstance(obj, dict):
+         return {key: move2device_exec(val, device) for key, val in obj.items()}
+     elif isinstance(obj, torch.Tensor):
+         obj = obj.detach()
+         if obj.device.type != device:
+             obj = obj.to(device)
+         return obj
+     elif "return_types" in str(type(obj)):
+         return move2device_exec(tuple(obj), device)
+     elif isinstance(obj, torch._C.device):
+         return torch.device(device)
+     else:
+         return obj
+
+
+ def move2target_device(buffer: ApiData, target_device):
+     # handle args
+     new_args = move2device_exec(buffer.args, target_device)
+
+     # handle kwargs
+     new_kwargs = move2device_exec(buffer.kwargs, target_device)
+
+     # handle result
+     new_results = move2device_exec(buffer.result, target_device)
+
+     if target_device == torch.device('cpu') or target_device == "cpu":
+         return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank)
+     else:
+         return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank)
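The -202/+197 hunk above matches the attl.py tensor transport layer: ApiData moves to msprobe.pytorch.api_accuracy_checker.common.utils, the raw torch.save/torch.load of the in-memory buffer is replaced by save_api_data/load_api_data helpers, remove_path now comes from core.common.file_utils, and download() switches from pathlib globbing to glob.glob over the NFS directory. A minimal usage sketch follows, assuming these interfaces behave as shown in the hunk; the session id, IP, port, NFS path, and API name are placeholders, not documented defaults.

```python
# A minimal sketch, assuming ATTL/ATTLConfig/ApiData behave as shown in the
# hunk above; all concrete values here are placeholders.
import torch

from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig

config = ATTLConfig(
    is_benchmark_device=False,
    connect_ip="127.0.0.1",
    connect_port=8080,
    nfs_path="/mnt/attl_share",  # must already exist; when set, no TCP client/server is started
)
attl = ATTL("demo_session", config, need_dump=True)

# Assuming ApiData keeps the name/args/kwargs/result/step/rank fields used in this hunk.
api_data = ApiData(
    name="Functional.conv2d.0",
    args=(torch.randn(1, 3, 8, 8),),
    kwargs={},
    result=None,
    step=0,
    rank=0,
)
attl.upload(api_data)       # written to <nfs_path>/<name>.pt via save_pt
received = attl.download()  # pick up a pending start*/*.pt/end* file with load_pt, then remove it
```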