mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,95 +1,95 @@
1
- from msprobe.core.common.log import logger
2
- from msprobe.core.compare.utils import rename_api
3
-
4
-
5
- dtype_mapping = {
6
- "Int8": "torch.int8",
7
- "UInt8": "torch.uint8",
8
- "Int16": "torch.int16",
9
- "UInt16": "torch.uint16",
10
- "Int32": "torch.int32",
11
- "UInt32": "torch.uint32",
12
- "Int64": "torch.int64",
13
- "UInt64": "torch.uint64",
14
- "Float16": "torch.float16",
15
- "Float32": "torch.float32",
16
- "Float64": "torch.float64",
17
- "Bool": "torch.bool",
18
- "BFloat16": "torch.bfloat16",
19
- "Complex64": "torch.complex64",
20
- "Complex128": "torch.complex128"
21
- }
22
-
23
-
24
- def check_struct_match(npu_dict, bench_dict, cross_frame=False):
25
- npu_struct_in = npu_dict.get("input_struct")
26
- bench_struct_in = bench_dict.get("input_struct")
27
- npu_struct_out = npu_dict.get("output_struct")
28
- bench_struct_out = bench_dict.get("output_struct")
29
-
30
- if cross_frame:
31
- npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
32
- npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
33
- is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
34
- if not is_match:
35
- if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
36
- return False
37
- struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
38
- struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
39
- is_match = struct_in_is_match and struct_out_is_match
40
- return is_match
41
-
42
-
43
- def check_type_shape_match(npu_struct, bench_struct):
44
- shape_type_match = False
45
- for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
46
- npu_type = npu_type_shape[0]
47
- npu_shape = npu_type_shape[1]
48
- bench_type = bench_type_shape[0]
49
- bench_shape = bench_type_shape[1]
50
- shape_match = npu_shape == bench_shape
51
- type_match = npu_type == bench_type
52
- if not type_match:
53
- ms_type=[["Float16", "Float32"], ["Float32", "Float16"],["Float16", "BFloat16"],["BFloat16", "Float16"]]
54
- torch_type=[["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"],
55
- ["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]]
56
- if ([npu_type, bench_type] in ms_type)or ([npu_type, bench_type] in torch_type):
57
- type_match = True
58
- else:
59
- type_match = False
60
- shape_type_match = shape_match and type_match
61
- if not shape_type_match:
62
- return False
63
- return shape_type_match
64
-
65
-
66
- def check_graph_mode(a_op_name, b_op_name):
67
- if "Aten" in a_op_name and "Aten" not in b_op_name:
68
- return True
69
- if "Aten" not in a_op_name and "Aten" in b_op_name:
70
- return True
71
- return False
72
-
73
-
74
- def fuzzy_check_op(npu_name_list, bench_name_list):
75
- if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
76
- return False
77
- is_match = True
78
- for npu_name, bench_name in zip(npu_name_list, bench_name_list):
79
- is_match = fuzzy_check_name(npu_name, bench_name)
80
- if not is_match:
81
- break
82
- return is_match
83
-
84
-
85
- def fuzzy_check_name(npu_name, bench_name):
86
- if "forward" in npu_name and "forward" in bench_name:
87
- is_match = rename_api(npu_name, "forward") == rename_api(bench_name, "forward")
88
- elif "backward" in npu_name and "backward" in bench_name:
89
- is_match = rename_api(npu_name, "backward") == rename_api(bench_name, "backward")
90
- else:
91
- is_match = npu_name == bench_name
92
- return is_match
93
-
94
-
95
-
1
+ from msprobe.core.common.log import logger
2
+ from msprobe.core.compare.utils import rename_api
3
+
4
+
5
+ dtype_mapping = {
6
+ "Int8": "torch.int8",
7
+ "UInt8": "torch.uint8",
8
+ "Int16": "torch.int16",
9
+ "UInt16": "torch.uint16",
10
+ "Int32": "torch.int32",
11
+ "UInt32": "torch.uint32",
12
+ "Int64": "torch.int64",
13
+ "UInt64": "torch.uint64",
14
+ "Float16": "torch.float16",
15
+ "Float32": "torch.float32",
16
+ "Float64": "torch.float64",
17
+ "Bool": "torch.bool",
18
+ "BFloat16": "torch.bfloat16",
19
+ "Complex64": "torch.complex64",
20
+ "Complex128": "torch.complex128"
21
+ }
22
+
23
+
24
+ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
25
+ npu_struct_in = npu_dict.get("input_struct")
26
+ bench_struct_in = bench_dict.get("input_struct")
27
+ npu_struct_out = npu_dict.get("output_struct")
28
+ bench_struct_out = bench_dict.get("output_struct")
29
+
30
+ if cross_frame:
31
+ npu_struct_in = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_in]
32
+ npu_struct_out = [(dtype_mapping.get(item[0], item[0]), item[1]) for item in npu_struct_out]
33
+ is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out
34
+ if not is_match:
35
+ if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
36
+ return False
37
+ struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
38
+ struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
39
+ is_match = struct_in_is_match and struct_out_is_match
40
+ return is_match
41
+
42
+
43
+ def check_type_shape_match(npu_struct, bench_struct):
44
+ shape_type_match = False
45
+ for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
46
+ npu_type = npu_type_shape[0]
47
+ npu_shape = npu_type_shape[1]
48
+ bench_type = bench_type_shape[0]
49
+ bench_shape = bench_type_shape[1]
50
+ shape_match = npu_shape == bench_shape
51
+ type_match = npu_type == bench_type
52
+ if not type_match:
53
+ ms_type=[["Float16", "Float32"], ["Float32", "Float16"],["Float16", "BFloat16"],["BFloat16", "Float16"]]
54
+ torch_type=[["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"],
55
+ ["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]]
56
+ if ([npu_type, bench_type] in ms_type)or ([npu_type, bench_type] in torch_type):
57
+ type_match = True
58
+ else:
59
+ type_match = False
60
+ shape_type_match = shape_match and type_match
61
+ if not shape_type_match:
62
+ return False
63
+ return shape_type_match
64
+
65
+
66
+ def check_graph_mode(a_op_name, b_op_name):
67
+ if "Aten" in a_op_name and "Aten" not in b_op_name:
68
+ return True
69
+ if "Aten" not in a_op_name and "Aten" in b_op_name:
70
+ return True
71
+ return False
72
+
73
+
74
+ def fuzzy_check_op(npu_name_list, bench_name_list):
75
+ if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list):
76
+ return False
77
+ is_match = True
78
+ for npu_name, bench_name in zip(npu_name_list, bench_name_list):
79
+ is_match = fuzzy_check_name(npu_name, bench_name)
80
+ if not is_match:
81
+ break
82
+ return is_match
83
+
84
+
85
+ def fuzzy_check_name(npu_name, bench_name):
86
+ if "forward" in npu_name and "forward" in bench_name:
87
+ is_match = rename_api(npu_name, "forward") == rename_api(bench_name, "forward")
88
+ elif "backward" in npu_name and "backward" in bench_name:
89
+ is_match = rename_api(npu_name, "backward") == rename_api(bench_name, "backward")
90
+ else:
91
+ is_match = npu_name == bench_name
92
+ return is_match
93
+
94
+
95
+
@@ -1,49 +1,49 @@
1
- import json
2
- from msprobe.core.common.file_check import FileOpen, check_file_type
3
- from msprobe.core.common.const import FileCheckConst, Const
4
- from msprobe.core.common.utils import CompareException
5
- from msprobe.core.common.log import logger
6
-
7
-
8
- def compare_cli(args):
9
- with FileOpen(args.input_path, "r") as file:
10
- input_param = json.load(file)
11
- npu_path = input_param.get("npu_path", None)
12
- bench_path = input_param.get("bench_path", None)
13
- frame_name = args.framework
14
- auto_analyze = not args.compare_only
15
- if frame_name == Const.PT_FRAMEWORK:
16
- from msprobe.pytorch.compare.pt_compare import compare
17
- from msprobe.pytorch.compare.distributed_compare import compare_distributed
18
- else:
19
- from msprobe.mindspore.compare.ms_compare import ms_compare
20
- from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
21
- if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
22
- input_param["npu_json_path"] = input_param.pop("npu_path")
23
- input_param["bench_json_path"] = input_param.pop("bench_path")
24
- input_param["stack_json_path"] = input_param.pop("stack_path")
25
- if frame_name == Const.PT_FRAMEWORK:
26
- compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
27
- fuzzy_match=args.fuzzy_match)
28
- else:
29
- kwargs = {
30
- "stack_mode": args.stack_mode,
31
- "auto_analyze": auto_analyze,
32
- "fuzzy_match": args.fuzzy_match,
33
- "cell_mapping": args.cell_mapping,
34
- "api_mapping": args.api_mapping,
35
- }
36
-
37
- ms_compare(input_param, args.output_path, **kwargs)
38
- elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
39
- kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
40
- if input_param.get("rank_id") is not None:
41
- ms_graph_compare(input_param, args.output_path)
42
- return
43
- if frame_name == Const.PT_FRAMEWORK:
44
- compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
45
- else:
46
- ms_compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
47
- else:
48
- logger.error("The npu_path and bench_path need to be of the same type.")
49
- raise CompareException(CompareException.INVALID_COMPARE_MODE)
1
+ import json
2
+ from msprobe.core.common.file_utils import FileOpen, check_file_type
3
+ from msprobe.core.common.const import FileCheckConst, Const
4
+ from msprobe.core.common.utils import CompareException
5
+ from msprobe.core.common.log import logger
6
+
7
+
8
+ def compare_cli(args):
9
+ with FileOpen(args.input_path, "r") as file:
10
+ input_param = json.load(file)
11
+ npu_path = input_param.get("npu_path", None)
12
+ bench_path = input_param.get("bench_path", None)
13
+ frame_name = args.framework
14
+ auto_analyze = not args.compare_only
15
+ if frame_name == Const.PT_FRAMEWORK:
16
+ from msprobe.pytorch.compare.pt_compare import compare
17
+ from msprobe.pytorch.compare.distributed_compare import compare_distributed
18
+ else:
19
+ from msprobe.mindspore.compare.ms_compare import ms_compare
20
+ from msprobe.mindspore.compare.distributed_compare import ms_compare_distributed, ms_graph_compare
21
+ if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE:
22
+ input_param["npu_json_path"] = input_param.pop("npu_path")
23
+ input_param["bench_json_path"] = input_param.pop("bench_path")
24
+ input_param["stack_json_path"] = input_param.pop("stack_path")
25
+ if frame_name == Const.PT_FRAMEWORK:
26
+ compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
27
+ fuzzy_match=args.fuzzy_match)
28
+ else:
29
+ kwargs = {
30
+ "stack_mode": args.stack_mode,
31
+ "auto_analyze": auto_analyze,
32
+ "fuzzy_match": args.fuzzy_match,
33
+ "cell_mapping": args.cell_mapping,
34
+ "api_mapping": args.api_mapping,
35
+ }
36
+
37
+ ms_compare(input_param, args.output_path, **kwargs)
38
+ elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
39
+ kwargs = {"stack_mode": args.stack_mode, "auto_analyze": auto_analyze, "fuzzy_match": args.fuzzy_match}
40
+ if input_param.get("rank_id") is not None:
41
+ ms_graph_compare(input_param, args.output_path)
42
+ return
43
+ if frame_name == Const.PT_FRAMEWORK:
44
+ compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
45
+ else:
46
+ ms_compare_distributed(npu_path, bench_path, args.output_path, **kwargs)
47
+ else:
48
+ logger.error("The npu_path and bench_path need to be of the same type.")
49
+ raise CompareException(CompareException.INVALID_COMPARE_MODE)