mindstudio-probe 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +131 -237
  7. msprobe/__init__.py +16 -1
  8. msprobe/{config/config.json → config.json} +47 -49
  9. msprobe/core/advisor/advisor.py +124 -124
  10. msprobe/core/advisor/advisor_const.py +58 -59
  11. msprobe/core/advisor/advisor_result.py +58 -58
  12. msprobe/core/common/const.py +402 -318
  13. msprobe/core/common/exceptions.py +99 -99
  14. msprobe/core/common/{file_check.py → file_utils.py} +523 -283
  15. msprobe/core/common/inplace_op_checker.py +38 -0
  16. msprobe/core/common/inplace_ops.yaml +251 -0
  17. msprobe/core/common/log.py +86 -69
  18. msprobe/core/common/utils.py +371 -616
  19. msprobe/core/common_config.py +78 -71
  20. msprobe/core/compare/acc_compare.py +472 -298
  21. msprobe/core/compare/check.py +180 -95
  22. msprobe/core/compare/compare_cli.py +69 -49
  23. msprobe/core/compare/highlight.py +259 -222
  24. msprobe/core/compare/multiprocessing_compute.py +174 -149
  25. msprobe/core/compare/npy_compare.py +310 -295
  26. msprobe/core/compare/utils.py +464 -429
  27. msprobe/core/data_dump/data_collector.py +153 -144
  28. msprobe/core/data_dump/data_processor/base.py +337 -293
  29. msprobe/core/data_dump/data_processor/factory.py +76 -59
  30. msprobe/core/data_dump/data_processor/mindspore_processor.py +192 -198
  31. msprobe/core/data_dump/data_processor/pytorch_processor.py +383 -389
  32. msprobe/core/data_dump/json_writer.py +117 -116
  33. msprobe/core/data_dump/scope.py +194 -178
  34. msprobe/core/grad_probe/constant.py +74 -70
  35. msprobe/core/grad_probe/grad_compare.py +170 -175
  36. msprobe/core/grad_probe/utils.py +77 -52
  37. msprobe/docs/01.installation.md +99 -0
  38. msprobe/docs/02.config_introduction.md +137 -0
  39. msprobe/docs/03.config_examples.md +237 -0
  40. msprobe/docs/04.acl_config_examples.md +78 -0
  41. msprobe/docs/05.data_dump_PyTorch.md +326 -0
  42. msprobe/docs/06.data_dump_MindSpore.md +285 -0
  43. msprobe/docs/07.accuracy_checker_PyTorch.md +297 -0
  44. msprobe/docs/08.accuracy_checker_online_PyTorch.md +238 -0
  45. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  46. msprobe/docs/10.accuracy_compare_PyTorch.md +327 -0
  47. msprobe/docs/11.accuracy_compare_MindSpore.md +333 -0
  48. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  49. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  50. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  51. msprobe/docs/15.free_benchmarking_PyTorch.md +170 -0
  52. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  53. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +205 -207
  54. msprobe/{pytorch/doc/在线精度比对.md → docs/18.online_dispatch.md} +89 -90
  55. msprobe/docs/FAQ.md +189 -0
  56. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  57. msprobe/docs/img/free_benchmark_framework.png +0 -0
  58. msprobe/docs/img/ms_dump.png +0 -0
  59. msprobe/docs/img/ms_layer.png +0 -0
  60. msprobe/docs/img/pt_dump.png +0 -0
  61. msprobe/mindspore/__init__.py +2 -1
  62. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +278 -245
  63. msprobe/mindspore/api_accuracy_checker/api_info.py +76 -69
  64. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  65. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  66. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  67. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  68. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  69. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  70. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  71. msprobe/mindspore/cell_processor.py +58 -34
  72. msprobe/mindspore/common/const.py +108 -87
  73. msprobe/mindspore/common/log.py +37 -37
  74. msprobe/mindspore/common/utils.py +97 -57
  75. msprobe/mindspore/compare/distributed_compare.py +62 -75
  76. msprobe/mindspore/compare/layer_mapping.py +146 -0
  77. msprobe/mindspore/compare/modify_mapping.py +107 -0
  78. msprobe/mindspore/compare/ms_compare.py +357 -117
  79. msprobe/mindspore/compare/ms_graph_compare.py +364 -317
  80. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  81. msprobe/mindspore/debugger/debugger_config.py +69 -74
  82. msprobe/mindspore/debugger/precision_debugger.py +150 -107
  83. msprobe/mindspore/dump/dump_tool_factory.py +50 -35
  84. msprobe/mindspore/dump/hook_cell/api_registry.py +128 -104
  85. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  86. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  87. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +994 -925
  88. msprobe/mindspore/dump/hook_cell/wrap_api.py +121 -0
  89. msprobe/mindspore/dump/jit_dump.py +96 -56
  90. msprobe/mindspore/dump/kernel_graph_dump.py +75 -60
  91. msprobe/mindspore/dump/kernel_kbyk_dump.py +79 -65
  92. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +131 -116
  93. msprobe/mindspore/free_benchmark/common/config.py +27 -12
  94. msprobe/mindspore/free_benchmark/common/handler_params.py +32 -17
  95. msprobe/mindspore/free_benchmark/common/utils.py +85 -71
  96. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  97. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +57 -42
  98. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +122 -107
  99. msprobe/mindspore/free_benchmark/handler/base_handler.py +105 -90
  100. msprobe/mindspore/free_benchmark/handler/check_handler.py +56 -41
  101. msprobe/mindspore/free_benchmark/handler/fix_handler.py +51 -36
  102. msprobe/mindspore/free_benchmark/handler/handler_factory.py +36 -21
  103. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +82 -67
  104. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +36 -21
  105. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +78 -63
  106. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +77 -0
  107. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +49 -34
  108. msprobe/mindspore/free_benchmark/perturbation/no_change.py +27 -12
  109. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +44 -27
  110. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +48 -33
  111. msprobe/mindspore/grad_probe/global_context.py +100 -91
  112. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  113. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  114. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  115. msprobe/mindspore/grad_probe/hook.py +94 -92
  116. msprobe/mindspore/grad_probe/utils.py +29 -28
  117. msprobe/mindspore/ms_config.py +128 -126
  118. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +60 -45
  119. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +49 -34
  120. msprobe/mindspore/runtime.py +4 -4
  121. msprobe/mindspore/service.py +297 -354
  122. msprobe/mindspore/task_handler_factory.py +24 -24
  123. msprobe/msprobe.py +105 -107
  124. msprobe/pytorch/__init__.py +23 -4
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +70 -55
  126. msprobe/pytorch/api_accuracy_checker/common/utils.py +246 -165
  127. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +230 -213
  128. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +632 -581
  129. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +416 -381
  132. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +90 -73
  133. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +265 -244
  134. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  135. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +370 -332
  136. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +221 -199
  137. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +150 -134
  138. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +518 -581
  139. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +213 -74
  140. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  141. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +218 -202
  142. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +370 -324
  143. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +227 -204
  144. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  145. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +244 -218
  146. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  147. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  148. msprobe/pytorch/bench_functions/__init__.py +30 -15
  149. msprobe/pytorch/bench_functions/apply_adam_w.py +43 -28
  150. msprobe/pytorch/bench_functions/confusion_transpose.py +34 -19
  151. msprobe/pytorch/bench_functions/fast_gelu.py +70 -55
  152. msprobe/pytorch/bench_functions/layer_norm_eval.py +21 -6
  153. msprobe/pytorch/bench_functions/linear.py +27 -12
  154. msprobe/pytorch/bench_functions/matmul_backward.py +63 -48
  155. msprobe/pytorch/bench_functions/npu_fusion_attention.py +538 -421
  156. msprobe/pytorch/bench_functions/rms_norm.py +30 -15
  157. msprobe/pytorch/bench_functions/rotary_mul.py +71 -52
  158. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +41 -26
  159. msprobe/pytorch/bench_functions/swiglu.py +70 -55
  160. msprobe/pytorch/common/__init__.py +17 -2
  161. msprobe/pytorch/common/compare_script.template +14 -14
  162. msprobe/pytorch/common/log.py +33 -32
  163. msprobe/pytorch/common/parse_json.py +54 -39
  164. msprobe/pytorch/common/utils.py +310 -300
  165. msprobe/pytorch/compare/distributed_compare.py +66 -66
  166. msprobe/pytorch/compare/mapping.yaml +607 -607
  167. msprobe/pytorch/compare/match.py +49 -33
  168. msprobe/pytorch/compare/pt_compare.py +82 -40
  169. msprobe/pytorch/debugger/debugger_config.py +108 -95
  170. msprobe/pytorch/debugger/precision_debugger.py +173 -125
  171. msprobe/pytorch/free_benchmark/__init__.py +23 -8
  172. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  173. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  174. msprobe/pytorch/free_benchmark/common/enums.py +65 -37
  175. msprobe/pytorch/free_benchmark/common/params.py +144 -129
  176. msprobe/pytorch/free_benchmark/common/utils.py +118 -102
  177. msprobe/pytorch/free_benchmark/compare/grad_saver.py +200 -179
  178. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +119 -104
  179. msprobe/pytorch/free_benchmark/main.py +120 -105
  180. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +28 -13
  181. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +56 -41
  182. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +105 -90
  183. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +119 -104
  184. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +87 -63
  185. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +83 -68
  186. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +43 -28
  187. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +60 -45
  188. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +34 -19
  189. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +256 -217
  190. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +54 -39
  191. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +38 -23
  192. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +45 -30
  193. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +185 -170
  194. msprobe/pytorch/function_factory.py +91 -75
  195. msprobe/pytorch/functional/module_dump.py +84 -0
  196. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  197. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  198. msprobe/pytorch/hook_module/__init__.py +16 -1
  199. msprobe/pytorch/hook_module/api_registry.py +166 -161
  200. msprobe/pytorch/hook_module/hook_module.py +118 -120
  201. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  202. msprobe/pytorch/hook_module/utils.py +28 -29
  203. msprobe/pytorch/hook_module/wrap_aten.py +111 -110
  204. msprobe/pytorch/hook_module/wrap_distributed.py +77 -78
  205. msprobe/pytorch/hook_module/wrap_functional.py +104 -105
  206. msprobe/pytorch/hook_module/wrap_npu_custom.py +85 -84
  207. msprobe/pytorch/hook_module/wrap_tensor.py +69 -71
  208. msprobe/pytorch/hook_module/wrap_torch.py +84 -86
  209. msprobe/pytorch/hook_module/wrap_vf.py +60 -62
  210. msprobe/pytorch/module_processer.py +153 -138
  211. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  212. msprobe/pytorch/online_dispatch/compare.py +235 -236
  213. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  214. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  215. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  216. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +57 -49
  217. msprobe/pytorch/online_dispatch/utils.py +127 -146
  218. msprobe/pytorch/parse.py +19 -4
  219. msprobe/pytorch/parse_tool/cli.py +31 -32
  220. msprobe/pytorch/parse_tool/lib/compare.py +259 -271
  221. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  222. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  224. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  225. msprobe/pytorch/parse_tool/lib/parse_tool.py +161 -158
  226. msprobe/pytorch/parse_tool/lib/utils.py +320 -321
  227. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  228. msprobe/pytorch/pt_config.py +317 -187
  229. msprobe/pytorch/service.py +311 -252
  230. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  231. msprobe/config/README.md +0 -539
  232. msprobe/mindspore/doc/compare.md +0 -58
  233. msprobe/mindspore/doc/dump.md +0 -217
  234. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  235. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  236. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  237. msprobe/pytorch/doc/FAQ.md +0 -193
  238. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  239. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  240. msprobe/pytorch/doc/dump.md +0 -260
  241. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  242. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  243. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  244. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  245. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  246. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  247. msprobe/pytorch/functional/data_processor.py +0 -0
  248. msprobe/pytorch/functional/dump_module.py +0 -39
  249. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  256. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  257. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  258. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  259. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  260. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  261. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  263. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  264. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  265. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  266. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  267. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  268. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  269. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  270. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  271. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  272. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  273. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  274. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  275. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  276. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  277. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  278. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,156 +1,155 @@
1
- import os
2
- import json
3
- import copy
4
- from datetime import datetime, timezone
5
-
6
- import pandas as pd
7
- import torch
8
- from msprobe.pytorch.common.log import logger
9
- from msprobe.core.common.file_check import FileOpen
10
- from .utils import np_save_data
11
-
12
-
13
- class DispatchRunParam:
14
- def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator):
15
- # static parameters are initialized by constructors, and dynamic parameters are constructed at run time
16
- self.debug_flag = debug_flag
17
- self.device_id = device_id
18
- self.root_npu_path = root_npu_path
19
- self.root_cpu_path = root_cpu_path
20
- self.process_num = process_num
21
- self.process_flag = False
22
- self.func_name = None
23
- self.func_namespace = None
24
- self.aten_api = None
25
- self.aten_api_overload_name = None
26
- self.single_api_index = None
27
- self.api_index = None
28
- self.dump_flag = None
29
- self.auto_dump_flag = None
30
- self.comparator = comparator
31
-
32
-
33
- class DisPatchDataInfo:
34
- def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock):
35
- self.cpu_args = cpu_args
36
- self.cpu_kwargs = cpu_kwargs
37
- self.all_summary = all_summary
38
- self.func = func
39
- self.npu_out_cpu = npu_out_cpu
40
- self.cpu_out = cpu_out
41
- self.lock = lock
42
-
43
-
44
- class TimeStatistics:
45
- def __init__(self, name_tag, run_param, timeout=5):
46
- self.debug = run_param.debug_flag
47
- if self.debug:
48
- self.fun = run_param.func_name
49
- self.device = run_param.device_id
50
- self.process = run_param.process_num
51
- self.index = run_param.single_api_index
52
- self.tag = name_tag
53
- self.timeout = timeout
54
- self.time = None
55
-
56
- def __enter__(self):
57
- if self.debug:
58
- self.time = datetime.now(tz=timezone.utc)
59
- logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
60
- f'Id[{self.index}]')
61
-
62
- def __exit__(self, exc_type, exc_val, exc_tb):
63
- if self.debug:
64
- cost_time = datetime.now(tz=timezone.utc) - self.time
65
- time_cost = f'Time[{self.tag}]-EXIT: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
66
- f'Id[{self.index}], time[{cost_time}]'
67
- hot_time_cost = "Hotspot " + time_cost
68
-
69
- if cost_time.total_seconds() > self.timeout:
70
- logger.info(hot_time_cost)
71
- else:
72
- logger.info(time_cost)
73
-
74
-
75
- def support_basic_type(data):
76
- if isinstance(data, (bool, int, float, torch.Tensor)):
77
- return True
78
- return False
79
-
80
-
81
- def dump_data(data, prefix, dump_path):
82
- if isinstance(data, (tuple, list)) and data:
83
- for i, item in enumerate(data):
84
- dump_data(item, "{}.{}".format(prefix, i), dump_path)
85
- return
86
- elif support_basic_type(data):
87
- if isinstance(data, torch.Tensor) and data.is_meta:
88
- return
89
- # dump data may greater than summary_list collect
90
- np_save_data(data, prefix, dump_path)
91
-
92
-
93
- def save_temp_summary(api_index, single_api_summary, path, lock):
94
- summary_path = os.path.join(path, f'summary.json')
95
- lock.acquire()
96
- with FileOpen(summary_path, "a") as f:
97
- json.dump([api_index, single_api_summary], f)
98
- f.write('\n')
99
- lock.release()
100
-
101
-
102
- def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
103
- cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs
104
- all_summary, func = data_info.all_summary, data_info.func
105
- npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock
106
- single_api_summary = []
107
-
108
- prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input'
109
- prefix_output = f'{run_param.aten_api}_{run_param.single_api_index}_output'
110
-
111
- accuracy_reached = False
112
- with TimeStatistics("COMPARE OUTPUT", run_param):
113
- run_param.comparator.compare_output(prefix_output, cpu_out, npu_out_cpu, None, None)
114
-
115
- # user set dump or auto mode will dump
116
- if run_param.dump_flag or (run_param.auto_dump_flag and not accuracy_reached):
117
- with TimeStatistics("DUMP INPUT", run_param):
118
- dump_data(cpu_args, prefix_input, run_param.root_npu_path)
119
- if len(cpu_kwargs) > 0:
120
- for k, v in cpu_kwargs.items():
121
- kwargs_prefix_name = prefix_input + f'_{k}'
122
- dump_data(v, kwargs_prefix_name, run_param.root_npu_path)
123
-
124
- with TimeStatistics("DUMP OUTPUT", run_param):
125
- dump_data(cpu_out, prefix_output, run_param.root_cpu_path)
126
- dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path)
127
-
128
- if run_param.process_num == 0:
129
- all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary)
130
- else:
131
- save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock)
132
-
133
-
134
- def get_torch_func(run_param):
135
- if hasattr(torch.ops, run_param.func_namespace):
136
- ops_func = getattr(torch.ops, run_param.func_namespace)
137
- if hasattr(ops_func, run_param.aten_api):
138
- ops_aten_func = getattr(ops_func, run_param.aten_api)
139
- if hasattr(ops_aten_func, run_param.aten_api_overload_name):
140
- ops_aten_overlaod_func = getattr(ops_aten_func, run_param.aten_api_overload_name)
141
- return ops_aten_overlaod_func
142
- return None
143
-
144
-
145
- def dispatch_multiprocess(run_param, dispatch_data_info):
146
- torch_func = get_torch_func(run_param)
147
- if torch_func is None:
148
- logger.error(f'can not find suitable call api:{run_param.aten_api}')
149
- else:
150
- dispatch_data_info.func = torch_func
151
- dispatch_workflow(run_param, dispatch_data_info)
152
-
153
-
154
- def error_call(err):
155
- logger.error(f'multiprocess {err}')
156
-
1
+ import os
2
+ import json
3
+ import copy
4
+ from datetime import datetime, timezone
5
+
6
+ import torch
7
+ from msprobe.pytorch.common.log import logger
8
+ from msprobe.core.common.file_utils import FileOpen, save_npy
9
+
10
+
11
+ class DispatchRunParam:
12
+ def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator):
13
+ # static parameters are initialized by constructors, and dynamic parameters are constructed at run time
14
+ self.debug_flag = debug_flag
15
+ self.device_id = device_id
16
+ self.root_npu_path = root_npu_path
17
+ self.root_cpu_path = root_cpu_path
18
+ self.process_num = process_num
19
+ self.process_flag = False
20
+ self.func_name = None
21
+ self.func_namespace = None
22
+ self.aten_api = None
23
+ self.aten_api_overload_name = None
24
+ self.single_api_index = None
25
+ self.api_index = None
26
+ self.dump_flag = None
27
+ self.auto_dump_flag = None
28
+ self.comparator = comparator
29
+
30
+
31
+ class DisPatchDataInfo:
32
+ def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock):
33
+ self.cpu_args = cpu_args
34
+ self.cpu_kwargs = cpu_kwargs
35
+ self.all_summary = all_summary
36
+ self.func = func
37
+ self.npu_out_cpu = npu_out_cpu
38
+ self.cpu_out = cpu_out
39
+ self.lock = lock
40
+
41
+
42
+ class TimeStatistics:
43
+ def __init__(self, name_tag, run_param, timeout=5):
44
+ self.debug = run_param.debug_flag
45
+ if self.debug:
46
+ self.fun = run_param.func_name
47
+ self.device = run_param.device_id
48
+ self.process = run_param.process_num
49
+ self.index = run_param.single_api_index
50
+ self.tag = name_tag
51
+ self.timeout = timeout
52
+ self.time = None
53
+
54
+ def __enter__(self):
55
+ if self.debug:
56
+ self.time = datetime.now(tz=timezone.utc)
57
+ logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
58
+ f'Id[{self.index}]')
59
+
60
+ def __exit__(self, exc_type, exc_val, exc_tb):
61
+ if self.debug:
62
+ cost_time = datetime.now(tz=timezone.utc) - self.time
63
+ time_cost = f'Time[{self.tag}]-EXIT: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
64
+ f'Id[{self.index}], time[{cost_time}]'
65
+ hot_time_cost = "Hotspot " + time_cost
66
+
67
+ if cost_time.total_seconds() > self.timeout:
68
+ logger.info(hot_time_cost)
69
+ else:
70
+ logger.info(time_cost)
71
+
72
+
73
+ def support_basic_type(data):
74
+ if isinstance(data, (bool, int, float, torch.Tensor)):
75
+ return True
76
+ return False
77
+
78
+
79
+ def dump_data(data, prefix, dump_path):
80
+ if isinstance(data, (tuple, list)) and data:
81
+ for i, item in enumerate(data):
82
+ dump_data(item, "{}.{}".format(prefix, i), dump_path)
83
+ return
84
+ elif support_basic_type(data):
85
+ if isinstance(data, torch.Tensor) and data.is_meta:
86
+ return
87
+ # dump data may greater than summary_list collect
88
+ path = os.path.join(dump_path, f'{prefix}.npy')
89
+ save_npy(data, path)
90
+
91
+
92
+ def save_temp_summary(api_index, single_api_summary, path, lock):
93
+ summary_path = os.path.join(path, f'summary.json')
94
+ lock.acquire()
95
+ with FileOpen(summary_path, "a") as f:
96
+ json.dump([api_index, single_api_summary], f)
97
+ f.write('\n')
98
+ lock.release()
99
+
100
+
101
+ def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
102
+ cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs
103
+ all_summary, func = data_info.all_summary, data_info.func
104
+ npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock
105
+ single_api_summary = []
106
+
107
+ prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input'
108
+ prefix_output = f'{run_param.aten_api}_{run_param.single_api_index}_output'
109
+
110
+ accuracy_reached = False
111
+ with TimeStatistics("COMPARE OUTPUT", run_param):
112
+ run_param.comparator.compare_output(prefix_output, cpu_out, npu_out_cpu, None, None)
113
+
114
+ # user set dump or auto mode will dump
115
+ if run_param.dump_flag or (run_param.auto_dump_flag and not accuracy_reached):
116
+ with TimeStatistics("DUMP INPUT", run_param):
117
+ dump_data(cpu_args, prefix_input, run_param.root_npu_path)
118
+ if len(cpu_kwargs) > 0:
119
+ for k, v in cpu_kwargs.items():
120
+ kwargs_prefix_name = prefix_input + f'_{k}'
121
+ dump_data(v, kwargs_prefix_name, run_param.root_npu_path)
122
+
123
+ with TimeStatistics("DUMP OUTPUT", run_param):
124
+ dump_data(cpu_out, prefix_output, run_param.root_cpu_path)
125
+ dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path)
126
+
127
+ if run_param.process_num == 0:
128
+ all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary)
129
+ else:
130
+ save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock)
131
+
132
+
133
+ def get_torch_func(run_param):
134
+ if hasattr(torch.ops, run_param.func_namespace):
135
+ ops_func = getattr(torch.ops, run_param.func_namespace)
136
+ if hasattr(ops_func, run_param.aten_api):
137
+ ops_aten_func = getattr(ops_func, run_param.aten_api)
138
+ if hasattr(ops_aten_func, run_param.aten_api_overload_name):
139
+ ops_aten_overlaod_func = getattr(ops_aten_func, run_param.aten_api_overload_name)
140
+ return ops_aten_overlaod_func
141
+ return None
142
+
143
+
144
+ def dispatch_multiprocess(run_param, dispatch_data_info):
145
+ torch_func = get_torch_func(run_param)
146
+ if torch_func is None:
147
+ logger.error(f'can not find suitable call api:{run_param.aten_api}')
148
+ else:
149
+ dispatch_data_info.func = torch_func
150
+ dispatch_workflow(run_param, dispatch_data_info)
151
+
152
+
153
+ def error_call(err):
154
+ logger.error(f'multiprocess {err}')
155
+