mindstudio-probe 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/LICENSE +201 -201
  2. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/METADATA +36 -34
  3. mindstudio_probe-1.0.4.dist-info/RECORD +276 -0
  4. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/WHEEL +1 -1
  5. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/entry_points.txt +1 -0
  6. msprobe/README.md +101 -237
  7. msprobe/{config/config.json → config.json} +49 -49
  8. msprobe/core/advisor/advisor.py +124 -124
  9. msprobe/core/advisor/advisor_const.py +59 -59
  10. msprobe/core/advisor/advisor_result.py +58 -58
  11. msprobe/core/common/const.py +341 -318
  12. msprobe/core/common/exceptions.py +99 -99
  13. msprobe/core/common/{file_check.py → file_utils.py} +478 -283
  14. msprobe/core/common/log.py +76 -69
  15. msprobe/core/common/utils.py +385 -616
  16. msprobe/core/common_config.py +85 -71
  17. msprobe/core/compare/acc_compare.py +299 -298
  18. msprobe/core/compare/check.py +95 -95
  19. msprobe/core/compare/compare_cli.py +49 -49
  20. msprobe/core/compare/highlight.py +223 -222
  21. msprobe/core/compare/multiprocessing_compute.py +149 -149
  22. msprobe/core/compare/npy_compare.py +295 -295
  23. msprobe/core/compare/utils.py +430 -429
  24. msprobe/core/data_dump/data_collector.py +154 -144
  25. msprobe/core/data_dump/data_processor/base.py +314 -293
  26. msprobe/core/data_dump/data_processor/factory.py +59 -59
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +186 -198
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +366 -389
  29. msprobe/core/data_dump/json_writer.py +96 -116
  30. msprobe/core/data_dump/scope.py +178 -178
  31. msprobe/core/grad_probe/constant.py +70 -70
  32. msprobe/core/grad_probe/grad_compare.py +171 -175
  33. msprobe/core/grad_probe/utils.py +64 -52
  34. msprobe/docs/01.installation.md +89 -0
  35. msprobe/docs/02.config_introduction.md +165 -0
  36. msprobe/docs/03.config_examples.md +247 -0
  37. msprobe/docs/04.acl_config_examples.md +76 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +198 -0
  39. msprobe/docs/06.data_dump_MindSpore.md +243 -0
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +274 -0
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +198 -0
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +68 -0
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +245 -0
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +202 -0
  45. msprobe/docs/12.overflow_check_PyTorch.md +79 -0
  46. msprobe/docs/13.overflow_check_MindSpore.md +31 -0
  47. msprobe/{pytorch/doc/parse_tool.md → docs/14.data_parse_PyTorch.md} +283 -286
  48. msprobe/docs/15.free_benchmarking_PyTorch.md +164 -0
  49. msprobe/{doc/grad_probe/grad_probe.md → docs/17.grad_probe.md} +207 -207
  50. msprobe/docs/FAQ_PyTorch.md +177 -0
  51. msprobe/docs/S02.report_free_benchmarking_validation_performance_baseline.md +146 -0
  52. msprobe/docs/img/free_benchmark_framework.png +0 -0
  53. msprobe/mindspore/__init__.py +1 -1
  54. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +254 -245
  55. msprobe/mindspore/api_accuracy_checker/api_info.py +69 -69
  56. msprobe/mindspore/api_accuracy_checker/api_runner.py +155 -151
  57. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +196 -196
  58. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +6 -0
  59. msprobe/mindspore/api_accuracy_checker/compute_element.py +238 -223
  60. msprobe/mindspore/api_accuracy_checker/main.py +8 -15
  61. msprobe/mindspore/api_accuracy_checker/type_mapping.py +113 -113
  62. msprobe/mindspore/api_accuracy_checker/utils.py +79 -62
  63. msprobe/mindspore/cell_processor.py +34 -34
  64. msprobe/mindspore/common/const.py +106 -87
  65. msprobe/mindspore/common/log.py +37 -37
  66. msprobe/mindspore/common/utils.py +81 -57
  67. msprobe/mindspore/compare/distributed_compare.py +75 -75
  68. msprobe/mindspore/compare/ms_compare.py +219 -117
  69. msprobe/mindspore/compare/ms_graph_compare.py +348 -317
  70. msprobe/mindspore/compare/ms_to_pt_api.yaml +399 -399
  71. msprobe/mindspore/debugger/debugger_config.py +66 -74
  72. msprobe/mindspore/debugger/precision_debugger.py +126 -107
  73. msprobe/mindspore/dump/dump_tool_factory.py +35 -35
  74. msprobe/mindspore/dump/hook_cell/api_registry.py +118 -104
  75. msprobe/mindspore/dump/hook_cell/hook_cell.py +55 -53
  76. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +922 -925
  77. msprobe/mindspore/dump/hook_cell/wrap_api.py +113 -0
  78. msprobe/mindspore/dump/jit_dump.py +72 -56
  79. msprobe/mindspore/dump/kernel_graph_dump.py +59 -60
  80. msprobe/mindspore/dump/kernel_kbyk_dump.py +64 -65
  81. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +116 -116
  82. msprobe/mindspore/free_benchmark/common/config.py +12 -12
  83. msprobe/mindspore/free_benchmark/common/handler_params.py +17 -17
  84. msprobe/mindspore/free_benchmark/common/utils.py +71 -71
  85. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +842 -842
  86. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +43 -42
  87. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +107 -107
  88. msprobe/mindspore/free_benchmark/handler/base_handler.py +90 -90
  89. msprobe/mindspore/free_benchmark/handler/check_handler.py +41 -41
  90. msprobe/mindspore/free_benchmark/handler/fix_handler.py +36 -36
  91. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -21
  92. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +67 -67
  93. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +21 -21
  94. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +63 -63
  95. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +51 -0
  96. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +35 -34
  97. msprobe/mindspore/free_benchmark/perturbation/no_change.py +12 -12
  98. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +29 -27
  99. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +33 -33
  100. msprobe/mindspore/grad_probe/global_context.py +90 -91
  101. msprobe/mindspore/grad_probe/grad_analyzer.py +231 -231
  102. msprobe/mindspore/grad_probe/grad_monitor.py +27 -27
  103. msprobe/mindspore/grad_probe/grad_stat_csv.py +131 -131
  104. msprobe/mindspore/grad_probe/hook.py +94 -92
  105. msprobe/mindspore/grad_probe/utils.py +29 -28
  106. msprobe/mindspore/ms_config.py +128 -126
  107. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +44 -45
  108. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +34 -34
  109. msprobe/mindspore/runtime.py +4 -4
  110. msprobe/mindspore/service.py +378 -354
  111. msprobe/mindspore/task_handler_factory.py +24 -24
  112. msprobe/msprobe.py +105 -107
  113. msprobe/pytorch/__init__.py +3 -3
  114. msprobe/pytorch/api_accuracy_checker/common/config.py +53 -55
  115. msprobe/pytorch/api_accuracy_checker/common/utils.py +214 -165
  116. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +213 -213
  117. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +606 -581
  118. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +132 -132
  119. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -390
  120. msprobe/pytorch/api_accuracy_checker/compare/compare.py +386 -381
  121. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +73 -73
  122. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +245 -244
  123. msprobe/pytorch/api_accuracy_checker/config.yaml +10 -10
  124. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +335 -332
  125. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +200 -199
  126. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +133 -134
  127. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +592 -581
  128. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +70 -74
  129. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +7 -4
  130. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +197 -202
  131. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +325 -324
  132. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +204 -204
  133. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +219 -218
  134. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +10 -10
  135. msprobe/pytorch/bench_functions/__init__.py +15 -15
  136. msprobe/pytorch/bench_functions/apply_adam_w.py +28 -28
  137. msprobe/pytorch/bench_functions/confusion_transpose.py +19 -19
  138. msprobe/pytorch/bench_functions/fast_gelu.py +55 -55
  139. msprobe/pytorch/bench_functions/layer_norm_eval.py +6 -6
  140. msprobe/pytorch/bench_functions/linear.py +12 -12
  141. msprobe/pytorch/bench_functions/matmul_backward.py +48 -48
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +509 -421
  143. msprobe/pytorch/bench_functions/rms_norm.py +15 -15
  144. msprobe/pytorch/bench_functions/rotary_mul.py +52 -52
  145. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +26 -26
  146. msprobe/pytorch/bench_functions/swiglu.py +55 -55
  147. msprobe/pytorch/common/__init__.py +2 -2
  148. msprobe/pytorch/common/compare_script.template +14 -14
  149. msprobe/pytorch/common/log.py +20 -31
  150. msprobe/pytorch/common/parse_json.py +39 -39
  151. msprobe/pytorch/common/utils.py +305 -300
  152. msprobe/pytorch/compare/distributed_compare.py +66 -66
  153. msprobe/pytorch/compare/mapping.yaml +607 -607
  154. msprobe/pytorch/compare/match.py +34 -33
  155. msprobe/pytorch/compare/pt_compare.py +50 -40
  156. msprobe/pytorch/debugger/debugger_config.py +95 -95
  157. msprobe/pytorch/debugger/precision_debugger.py +125 -125
  158. msprobe/pytorch/free_benchmark/__init__.py +8 -8
  159. msprobe/pytorch/free_benchmark/common/constant.py +70 -70
  160. msprobe/pytorch/free_benchmark/common/counter.py +71 -71
  161. msprobe/pytorch/free_benchmark/common/enums.py +37 -37
  162. msprobe/pytorch/free_benchmark/common/params.py +129 -129
  163. msprobe/pytorch/free_benchmark/common/utils.py +102 -102
  164. msprobe/pytorch/free_benchmark/compare/grad_saver.py +179 -179
  165. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -104
  166. msprobe/pytorch/free_benchmark/main.py +105 -105
  167. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -13
  168. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -41
  169. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -90
  170. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -104
  171. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -63
  172. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -68
  173. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -28
  174. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -45
  175. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -19
  176. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +217 -217
  177. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -39
  178. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +23 -23
  179. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +30 -30
  180. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -170
  181. msprobe/pytorch/function_factory.py +76 -75
  182. msprobe/pytorch/functional/dump_module.py +39 -39
  183. msprobe/pytorch/grad_probe/grad_monitor.py +91 -90
  184. msprobe/pytorch/grad_probe/grad_stat_csv.py +128 -128
  185. msprobe/pytorch/hook_module/api_registry.py +161 -161
  186. msprobe/pytorch/hook_module/hook_module.py +120 -120
  187. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1879 -1877
  188. msprobe/pytorch/hook_module/utils.py +30 -29
  189. msprobe/pytorch/hook_module/wrap_aten.py +110 -110
  190. msprobe/pytorch/hook_module/wrap_distributed.py +78 -78
  191. msprobe/pytorch/hook_module/wrap_functional.py +105 -105
  192. msprobe/pytorch/hook_module/wrap_npu_custom.py +93 -84
  193. msprobe/pytorch/hook_module/wrap_tensor.py +71 -71
  194. msprobe/pytorch/hook_module/wrap_torch.py +86 -86
  195. msprobe/pytorch/hook_module/wrap_vf.py +62 -62
  196. msprobe/pytorch/module_processer.py +138 -138
  197. msprobe/pytorch/online_dispatch/__init__.py +20 -20
  198. msprobe/pytorch/online_dispatch/compare.py +236 -236
  199. msprobe/pytorch/online_dispatch/dispatch.py +271 -271
  200. msprobe/pytorch/online_dispatch/dump_compare.py +155 -156
  201. msprobe/pytorch/online_dispatch/single_compare.py +391 -391
  202. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +49 -49
  203. msprobe/pytorch/online_dispatch/utils.py +130 -146
  204. msprobe/pytorch/parse.py +4 -4
  205. msprobe/pytorch/parse_tool/cli.py +32 -32
  206. msprobe/pytorch/parse_tool/lib/compare.py +260 -271
  207. msprobe/pytorch/parse_tool/lib/config.py +52 -52
  208. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -31
  209. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -102
  210. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -54
  211. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -158
  212. msprobe/pytorch/parse_tool/lib/utils.py +316 -321
  213. msprobe/pytorch/parse_tool/lib/visualization.py +85 -91
  214. msprobe/pytorch/pt_config.py +188 -187
  215. msprobe/pytorch/service.py +246 -252
  216. mindstudio_probe-1.0.3.dist-info/RECORD +0 -272
  217. msprobe/config/README.md +0 -539
  218. msprobe/mindspore/doc/compare.md +0 -58
  219. msprobe/mindspore/doc/dump.md +0 -217
  220. msprobe/mindspore/dump/hook_cell/wrap_functional.py +0 -91
  221. msprobe/mindspore/dump/hook_cell/wrap_tensor.py +0 -63
  222. msprobe/pytorch/doc/FAQ.md +0 -193
  223. msprobe/pytorch/doc/api_accuracy_checker.md +0 -313
  224. msprobe/pytorch/doc/api_accuracy_checker_online.md +0 -187
  225. msprobe/pytorch/doc/dump.md +0 -260
  226. msprobe/pytorch/doc/msprobe/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +0 -182
  227. msprobe/pytorch/doc/ptdbg_ascend_compare.md +0 -240
  228. msprobe/pytorch/doc/ptdbg_ascend_overview.md +0 -68
  229. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +0 -381
  230. msprobe/pytorch/doc/run_overflow_check.md +0 -25
  231. msprobe/pytorch/doc/在线精度比对.md +0 -90
  232. msprobe/pytorch/doc/无标杆工具场景验证和性能基线报告.md +0 -151
  233. {mindstudio_probe-1.0.3.dist-info → mindstudio_probe-1.0.4.dist-info}/top_level.txt +0 -0
  234. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_1.png +0 -0
  235. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_2.png +0 -0
  236. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_3.png +0 -0
  237. /msprobe/{pytorch/doc → docs}/img/BLOOM-7B_4.png +0 -0
  238. /msprobe/{pytorch/doc → docs}/img/GPT-3_1.png +0 -0
  239. /msprobe/{pytorch/doc → docs}/img/GPT-3_2.png +0 -0
  240. /msprobe/{pytorch/doc → docs}/img/GPT-3_3.png +0 -0
  241. /msprobe/{pytorch/doc → docs}/img/GPT-3_4.png +0 -0
  242. /msprobe/{pytorch/doc → docs}/img/GPT-3_5.png +0 -0
  243. /msprobe/{pytorch/doc → docs}/img/GPT-3_6.png +0 -0
  244. /msprobe/{pytorch/doc → docs}/img/GPT-3_7.png +0 -0
  245. /msprobe/{pytorch/doc → docs}/img/GPT-3_8.png +0 -0
  246. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_1.png +0 -0
  247. /msprobe/{pytorch/doc → docs}/img/YOLOV5S_2.png +0 -0
  248. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_details.png +0 -0
  249. /msprobe/{pytorch/doc → docs}/img/accuracy_checking_result.png +0 -0
  250. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_details.png +0 -0
  251. /msprobe/{pytorch/doc → docs}/img/api_precision_compare_result.png +0 -0
  252. /msprobe/{pytorch/doc → docs}/img/auto_analyze_log.png +0 -0
  253. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl.png +0 -0
  254. /msprobe/{pytorch/doc → docs}/img/compare_result_pkl_md5.png.png +0 -0
  255. /msprobe/{pytorch/doc → docs}/img/cpu_info.png +0 -0
  256. /msprobe/{config → docs}/img/free_benchmark.png +0 -0
  257. /msprobe/{doc/grad_probe/img/image-1.png → docs/img/grad_probe_image-1.png} +0 -0
  258. /msprobe/{doc/grad_probe/img/image-2.png → docs/img/grad_probe_image-2.png} +0 -0
  259. /msprobe/{doc/grad_probe/img/image-3.png → docs/img/grad_probe_image-3.png} +0 -0
  260. /msprobe/{doc/grad_probe/img/image-4.png → docs/img/grad_probe_image-4.png} +0 -0
  261. /msprobe/{doc/grad_probe/img/image.png → docs/img/grad_probe_image.png} +0 -0
  262. /msprobe/{pytorch/doc → docs}/img/module_compare.png +0 -0
@@ -1,117 +1,219 @@
1
- import os.path
2
- from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
3
- task_dumppath_get, load_yaml, load_npy
4
- from msprobe.core.common.file_check import create_directory
5
- from msprobe.core.common.const import Const
6
- from msprobe.core.common.log import logger
7
- from msprobe.core.common.exceptions import FileCheckException
8
- from msprobe.core.compare.acc_compare import Comparator
9
- from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
10
-
11
-
12
- class MSComparator(Comparator):
13
- def __init__(self, cell_mapping=None, api_mapping=None):
14
- self.frame_name = MSComparator.__name__
15
- self.cell_mapping = cell_mapping
16
- self.api_mapping = api_mapping
17
- self.cross_frame = cell_mapping is not None or api_mapping is not None
18
- self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
19
- self.api_mapping_dict = {}
20
- if api_mapping is not None:
21
- self.ms_to_pt_mapping = self.load_internal_api()
22
-
23
- def load_internal_api(self):
24
- cur_path = os.path.dirname(os.path.realpath(__file__))
25
- yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml")
26
- return load_yaml(yaml_path)
27
-
28
- def load_mapping_file(self, mapping_file):
29
- if isinstance(mapping_file, str):
30
- mapping_dict = load_yaml(mapping_file)
31
- else:
32
- mapping_dict = {}
33
- return mapping_dict
34
-
35
- def process_cell_mapping(self, npu_op_name):
36
- npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
37
- if self.cell_mapping_dict:
38
- for index, op_name in enumerate(npu_op_name):
39
- # get cell name & class name from op_name
40
- # Cell.fc1.Dense.forward.0.input.0
41
- cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
42
- if cell_name in self.cell_mapping_dict:
43
- npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
44
- return npu_op_name
45
-
46
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
47
- npu_op_name = npu_dict["op_name"].copy()
48
- bench_op_name = bench_dict["op_name"].copy()
49
-
50
- if self.api_mapping is not None:
51
- npu_op_name = self.process_api_mapping(npu_op_name, bench_op_name)
52
- if self.cell_mapping is not None:
53
- npu_op_name = self.process_cell_mapping(npu_op_name)
54
-
55
- struct_match = check_struct_match(npu_dict, bench_dict, cross_frame=self.cross_frame)
56
- if not fuzzy_match:
57
- return npu_op_name == bench_op_name and struct_match
58
- is_match = True
59
- try:
60
- is_match = fuzzy_check_op(npu_op_name, bench_op_name)
61
- except Exception as err:
62
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
63
- is_match = False
64
- return is_match and struct_match
65
-
66
- def read_npy_data(self, dir_path, file_name, load_pt_file=False):
67
- data_path = os.path.join(dir_path, file_name)
68
- if load_pt_file:
69
- import torch
70
- from msprobe.pytorch.common.utils import load_pt
71
- data_value = load_pt(data_path).detach()
72
- if data_value.dtype == torch.bfloat16:
73
- data_value = data_value.to(torch.float32)
74
- data_value = data_value.numpy()
75
- else:
76
- data_value = load_npy(data_path)
77
- return data_value
78
-
79
- def api_replace(self, npu_op_name, target, para):
80
- for idx, _ in enumerate(npu_op_name):
81
- npu_op_name[idx] = npu_op_name[idx].replace(target, para)
82
- return npu_op_name
83
-
84
- def process_api_mapping(self, npu_op_name, bench_op_name):
85
- # get api name & class name from op_name
86
- # Functional.addcmul.0.forward.input.0
87
- ms_api_name = npu_op_name[0].rsplit(Const.SEP, 4)[0]
88
- pt_api_name = bench_op_name[0].rsplit(Const.SEP, 4)[0]
89
- class_name = ms_api_name.split(Const.SEP)[0]
90
- if class_name == "Mint":
91
- return self.api_replace(npu_op_name, "Mint", "Torch")
92
- elif class_name == "MintFunctional":
93
- return self.api_replace(npu_op_name, "MintFunctional", "Functional")
94
- elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
95
- return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
96
- else:
97
- return npu_op_name
98
-
99
-
100
- def ms_compare(input_param, output_path, **kwargs):
101
- try:
102
- stack_mode = kwargs.get('stack_mode', False)
103
- auto_analyze = kwargs.get('auto_analyze', True)
104
- fuzzy_match = kwargs.get('fuzzy_match', False)
105
- cell_mapping = kwargs.get('cell_mapping', None)
106
- api_mapping = kwargs.get('api_mapping', None)
107
- summary_compare, md5_compare = task_dumppath_get(input_param)
108
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
109
- create_directory(output_path)
110
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
111
- except (CompareException, FileCheckException) as error:
112
- logger.error('Compare failed. Please check the arguments and do it again!')
113
- raise CompareException(error.code) from error
114
- ms_comparator = MSComparator(cell_mapping, api_mapping)
115
- ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
116
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
117
- md5_compare=md5_compare)
1
+ import os
2
+ import copy
3
+ from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
4
+ task_dumppath_get
5
+ from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
6
+ from msprobe.core.common.const import Const, CompareConst
7
+ from msprobe.core.common.log import logger
8
+ from msprobe.core.common.exceptions import FileCheckException
9
+ from msprobe.core.compare.acc_compare import Comparator
10
+ from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
11
+
12
+
13
+ class MSComparator(Comparator):
14
+ def __init__(self, cell_mapping=None, api_mapping=None):
15
+ self.frame_name = MSComparator.__name__
16
+ self.cell_mapping = cell_mapping
17
+ self.api_mapping = api_mapping
18
+ self.cross_frame = cell_mapping is not None or api_mapping is not None
19
+ self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
20
+ self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
21
+ if api_mapping is not None:
22
+ self.ms_to_pt_mapping = self.load_internal_api()
23
+
24
+ def load_internal_api(self):
25
+ cur_path = os.path.dirname(os.path.realpath(__file__))
26
+ yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml")
27
+ return load_yaml(yaml_path)
28
+
29
+ def load_mapping_file(self, mapping_file):
30
+ if isinstance(mapping_file, str):
31
+ mapping_dict = load_yaml(mapping_file)
32
+ else:
33
+ mapping_dict = {}
34
+ return mapping_dict
35
+
36
+ def process_cell_mapping(self, npu_op_name):
37
+ npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
38
+ if self.cell_mapping_dict:
39
+ for index, op_name in enumerate(npu_op_name):
40
+ # get cell name & class name from op_name
41
+ # Cell.fc1.Dense.forward.0.input.0
42
+ cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
43
+ if cell_name in self.cell_mapping_dict:
44
+ npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
45
+ return npu_op_name
46
+
47
+ def check_op(self, npu_dict, bench_dict, fuzzy_match):
48
+ npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
49
+ npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
50
+ if self.cell_mapping is not None:
51
+ npu_op_name = self.process_cell_mapping(npu_op_name)
52
+ if self.api_mapping is not None:
53
+ npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
54
+ if isinstance(self.api_mapping, str):
55
+ npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new, bench_dict_new)
56
+ if target_dict:
57
+ bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
58
+ npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
59
+ struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
60
+ if not fuzzy_match:
61
+ return npu_op_name == bench_op_name and struct_match
62
+ is_match = True
63
+ try:
64
+ is_match = fuzzy_check_op(npu_op_name, bench_op_name)
65
+ except Exception as err:
66
+ logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
67
+ is_match = False
68
+ return is_match and struct_match
69
+
70
+ def read_npy_data(self, dir_path, file_name, load_pt_file=False):
71
+ data_path = os.path.join(dir_path, file_name)
72
+ if load_pt_file:
73
+ import torch
74
+ from msprobe.pytorch.common.utils import load_pt
75
+ data_value = load_pt(data_path).detach()
76
+ if data_value.dtype == torch.bfloat16:
77
+ data_value = data_value.to(torch.float32)
78
+ data_value = data_value.numpy()
79
+ else:
80
+ data_value = load_npy(data_path)
81
+ return data_value
82
+
83
+ def api_replace(self, npu_op_name, target, para):
84
+ for idx, _ in enumerate(npu_op_name):
85
+ npu_op_name[idx] = npu_op_name[idx].replace(target, para)
86
+ return npu_op_name
87
+
88
+ def process_internal_api_mapping(self, npu_op_name, bench_op_name):
89
+ # get api name & class name from op_name
90
+ # Functional.addcmul.0.forward.input.0
91
+ npu_op_name, bench_op_name = npu_op_name.copy(), bench_op_name.copy()
92
+ ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
93
+ pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
94
+ class_name = ms_api_name.split(Const.SEP)[0]
95
+ if class_name == "Mint":
96
+ return self.api_replace(npu_op_name, "Mint", "Torch")
97
+ elif class_name == "MintFunctional":
98
+ return self.api_replace(npu_op_name, "MintFunctional", "Functional")
99
+ elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
100
+ return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
101
+ else:
102
+ return npu_op_name
103
+
104
+ def remove_element(self, op_name, struct, summary, idx):
105
+ del op_name[idx]
106
+ del struct[idx]
107
+ del summary[idx]
108
+
109
+ def get_api_name(self, api_list):
110
+ return api_list[0] + Const.SEP + api_list[1]
111
+
112
+ def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
113
+ """
114
+ Transform user mapping API based on new NPU and benchmark dictionaries.
115
+ Parameters:
116
+ new_npu_dict (dict): New NPU operation dictionary.
117
+ new_bench_dict (dict): New benchmark operation dictionary.
118
+ Returns:
119
+ tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
120
+ """
121
+ npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
122
+ npu_struct_in, bench_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT), new_bench_dict.get(CompareConst.INPUT_STRUCT)
123
+ npu_struct_out, bench_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT), new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
124
+ npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
125
+ npu_in_len, bench_in_len, npu_out_len, bench_out_len = len(npu_struct_in), len(bench_struct_in), len(npu_struct_out), len(bench_struct_out)
126
+ ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
127
+ ms_api_name = self.get_api_name(ms_api_list)
128
+ pt_api_name = self.get_api_name(pt_api_list)
129
+ target_dict = {}
130
+ for api_dict in self.api_mapping_dict:
131
+ if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
132
+ ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
133
+ ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
134
+ if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
135
+ logger.warning("The user-defined mapping table is incorrect, make sure that the number of parameters is equal" )
136
+ break
137
+ ms_out_list = api_dict.get("ms_output", [])
138
+ for idx in reversed(range(npu_out_len)):
139
+ if idx not in ms_out_list:
140
+ del npu_struct_out[idx]
141
+ del npu_summary[idx + npu_in_len]
142
+ del npu_op_name[idx + npu_in_len]
143
+ pt_out_list = api_dict.get("pt_output", [])
144
+ for idx in reversed(range(bench_out_len)):
145
+ if idx not in pt_out_list:
146
+ del bench_struct_out[idx]
147
+ del bench_summary[idx + bench_in_len]
148
+ del bench_op_name[idx + bench_in_len]
149
+ ms_para_list = api_dict.get("ms_args", [])
150
+ for idx in reversed(range(npu_in_len)):
151
+ if idx not in ms_para_list:
152
+ self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
153
+ pt_para_list = api_dict.get("pt_args", [])
154
+ for idx in reversed(range(bench_in_len)):
155
+ if idx not in pt_para_list:
156
+ self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
157
+ npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
158
+ npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
159
+ target_dict = api_dict
160
+ break
161
+ if target_dict:
162
+ new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in, CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
163
+ new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
164
+ return new_npu_dict, new_bench_dict, target_dict
165
+
166
+ def para_sequence_update(self, npu_op_name, bench_op_name):
167
+ for idx, _ in enumerate(npu_op_name):
168
+ bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
169
+ if len(bench_op_name_list) != 0:
170
+ npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
171
+ return npu_op_name
172
+
173
+ def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
174
+ ms_user_args_list = api_dict.get("ms_args", [])
175
+ ms_user_output_list = api_dict.get("ms_output", [])
176
+ npu_struct_in = npu_dict.get(CompareConst.INPUT_STRUCT)
177
+ npu_struct_out = npu_dict.get(CompareConst.OUTPUT_STRUCT)
178
+ npu_in_len = len(npu_struct_in)
179
+ npu_out_len = len(npu_struct_out)
180
+ if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
181
+ return del_bench_dict
182
+ ms_input_args_list = [i for i in range(npu_in_len)]
183
+ input_sub_list =list(set(ms_input_args_list) - set(ms_user_args_list))
184
+ ms_output_args_list = [i for i in range(npu_out_len)]
185
+ output_sub_list =list(set(ms_output_args_list) - set(ms_user_output_list))
186
+ bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
187
+ bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
188
+ bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
189
+ bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
190
+ for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
191
+ bench_op_name.insert(idx, CompareConst.NAN)
192
+ bench_struct_in.insert(idx, CompareConst.NAN)
193
+ bench_summary.insert(idx, CompareConst.NAN)
194
+ for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
195
+ bench_op_name.insert(npu_in_len + idx, CompareConst.NAN)
196
+ bench_struct_out.insert(idx, CompareConst.NAN)
197
+ bench_summary.insert(npu_in_len + idx, CompareConst.NAN)
198
+ del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
199
+ return del_bench_dict
200
+
201
+
202
+ def ms_compare(input_param, output_path, **kwargs):
203
+ try:
204
+ stack_mode = kwargs.get('stack_mode', False)
205
+ auto_analyze = kwargs.get('auto_analyze', True)
206
+ fuzzy_match = kwargs.get('fuzzy_match', False)
207
+ cell_mapping = kwargs.get('cell_mapping', None)
208
+ api_mapping = kwargs.get('api_mapping', None)
209
+ summary_compare, md5_compare = task_dumppath_get(input_param)
210
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
211
+ create_directory(output_path)
212
+ check_compare_param(input_param, output_path, summary_compare, md5_compare)
213
+ except (CompareException, FileCheckException) as error:
214
+ logger.error('Compare failed. Please check the arguments and do it again!')
215
+ raise CompareException(error.code) from error
216
+ ms_comparator = MSComparator(cell_mapping, api_mapping)
217
+ ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
218
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
219
+ md5_compare=md5_compare)