mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,29 +1,170 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
- import copy
3
- from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
4
- task_dumppath_get
5
- from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy
6
- from msprobe.core.common.const import Const, CompareConst
7
- from msprobe.core.common.log import logger
17
+ import re
18
+
19
+ from collections import defaultdict
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+
24
+ from msprobe.core.common.const import CompareConst, Const
8
25
  from msprobe.core.common.exceptions import FileCheckException
26
+ from msprobe.core.common.file_utils import (FileOpen, create_directory, load_json,
27
+ load_npy, load_yaml)
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.core.common.utils import (CompareException, check_compare_param,
30
+ check_configuration_param,
31
+ get_dump_mode, set_dump_path, check_op_str_pattern_valid)
32
+ from msprobe.core.compare.check import dtype_mapping
9
33
  from msprobe.core.compare.acc_compare import Comparator
10
- from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
34
+ from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
11
35
 
12
36
 
13
37
  class MSComparator(Comparator):
14
- def __init__(self, cell_mapping=None, api_mapping=None):
38
+ """
39
+ 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
40
+ cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
41
+ api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
42
+ data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
43
+ is_cross_framework: 是否跨框架。
44
+ """
45
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
15
46
  self.frame_name = MSComparator.__name__
16
47
  self.cell_mapping = cell_mapping
17
48
  self.api_mapping = api_mapping
18
- self.cross_frame = cell_mapping is not None or api_mapping is not None
49
+ self.data_mapping = data_mapping
50
+ if data_mapping:
51
+ self.cross_frame = is_cross_framework
52
+ else:
53
+ self.cross_frame = cell_mapping is not None or api_mapping is not None
19
54
  self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
20
55
  self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
21
56
  if api_mapping is not None:
22
57
  self.ms_to_pt_mapping = self.load_internal_api()
58
+
59
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
60
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
61
+ elif isinstance(self.data_mapping, dict):
62
+ self.data_mapping_dict = self.data_mapping
63
+ else:
64
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
65
+ f"{type(self.data_mapping)}")
66
+
67
+ @classmethod
68
+ def calc_accuracy(cls, result_df, dump_mode, header):
69
+ condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
70
+ result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
71
+ result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
72
+
73
+ def calc_summary_diff(data_type: str):
74
+ def type_check(val):
75
+ check_series = pd.Series(False, index=val.index)
76
+ val_str = val.astype(str)
77
+ check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
78
+ return check_series
23
79
 
80
+ def get_number(val):
81
+ return pd.to_numeric(val.astype(str), errors='coerce')
82
+
83
+ ms_val = result_df['NPU ' + data_type]
84
+ pt_val = result_df['Bench ' + data_type]
85
+ diff_name = data_type.capitalize() + ' diff'
86
+ rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
87
+ condition_na = ~type_check(ms_val) | ~type_check(pt_val)
88
+ result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
89
+ result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
90
+ condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
91
+ condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
92
+ result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
93
+ condition_pt_zero = pt_val == 0
94
+ result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
95
+ condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
96
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
97
+ pt_val[condition_ref_err] * 100)
98
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
99
+ .abs().astype(str) + '%')
100
+ magnitude = get_number(result_df[diff_name]).abs() / (
101
+ pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
102
+ return magnitude > CompareConst.MAGNITUDE
103
+
104
+ if dump_mode == Const.MD5:
105
+ condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
106
+ result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
107
+ result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
108
+ elif dump_mode == Const.SUMMARY:
109
+ warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
110
+ warning_flag = pd.DataFrame(warning_list).all()
111
+ result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
112
+ result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
113
+ result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
114
+ else:
115
+ fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
116
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
117
+ CompareConst.ERROR_MESSAGE]
118
+ result_df.loc[~condition_no_bench, fill_cols] = ''
119
+ result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
120
+ return result_df[header]
121
+
122
+ @classmethod
123
+ def make_result_df(cls, result, stack_mode, dump_mode):
124
+ header = CompareConst.HEAD_OF_COMPARE_MODE[dump_mode]
125
+
126
+ if stack_mode:
127
+ header.append(CompareConst.STACK)
128
+ if dump_mode == Const.ALL:
129
+ header.append(CompareConst.DATA_NAME)
130
+ result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
131
+ 'op_name_y': CompareConst.BENCH_NAME,
132
+ 'dtype_x': CompareConst.NPU_DTYPE,
133
+ 'dtype_y': CompareConst.BENCH_DTYPE,
134
+ 'shape_x': CompareConst.NPU_SHAPE,
135
+ 'shape_y': CompareConst.BENCH_SHAPE,
136
+ 'md5_x': CompareConst.NPU_MD5,
137
+ 'md5_y': CompareConst.BENCH_MD5,
138
+ 'data_name_x': CompareConst.DATA_NAME,
139
+ 'stack_info_x': CompareConst.STACK}, inplace=True)
140
+
141
+ npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
142
+ bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
143
+ CompareConst.BENCH_NORM]
144
+ def set_summary(summary):
145
+ if summary == CompareConst.N_A:
146
+ return [CompareConst.N_A] * 4
147
+ summary_list = []
148
+ for i in summary:
149
+ if i is None:
150
+ summary_list.append(CompareConst.N_A)
151
+ elif str(i).lower() == 'nan':
152
+ summary_list.append(CompareConst.NAN)
153
+ else:
154
+ summary_list.append(i)
155
+ return summary_list
156
+
157
+ result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
158
+ result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
159
+ result_df = pd.DataFrame(columns=header)
160
+ for h in header:
161
+ if h in result.columns:
162
+ result_df[h] = result[h]
163
+ return cls.calc_accuracy(result_df, dump_mode, header)
164
+
24
165
  def load_internal_api(self):
25
166
  cur_path = os.path.dirname(os.path.realpath(__file__))
26
- yaml_path = os.path.join(cur_path,"ms_to_pt_api.yaml")
167
+ yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
27
168
  return load_yaml(yaml_path)
28
169
 
29
170
  def load_mapping_file(self, mapping_file):
@@ -34,171 +175,184 @@ class MSComparator(Comparator):
34
175
  return mapping_dict
35
176
 
36
177
  def process_cell_mapping(self, npu_op_name):
37
- npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
178
+ if not npu_op_name or not re.match(r'.+(?:for|back)ward\..+', npu_op_name):
179
+ return CompareConst.N_A
180
+ npu_op_name = npu_op_name.replace("Cell", "Module", 1)
38
181
  if self.cell_mapping_dict:
39
- for index, op_name in enumerate(npu_op_name):
40
- # get cell name & class name from op_name
41
- # Cell.fc1.Dense.forward.0.input.0
42
- cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
43
- if cell_name in self.cell_mapping_dict:
44
- npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
182
+ # get cell name & class name from op_name
183
+ # Cell.fc1.Dense.forward.0.input.0
184
+ cell_name = re.split(r'\.(?:for|back)ward\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
185
+ if cell_name in self.cell_mapping_dict:
186
+ npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
45
187
  return npu_op_name
46
188
 
47
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
48
- npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
49
- npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
50
- if self.cell_mapping is not None:
51
- npu_op_name = self.process_cell_mapping(npu_op_name)
52
- if self.api_mapping is not None:
53
- npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
54
- if isinstance(self.api_mapping, str):
55
- npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new, bench_dict_new)
56
- if target_dict:
57
- bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
58
- npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
59
- struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
60
- if not fuzzy_match:
61
- return npu_op_name == bench_op_name and struct_match
62
- is_match = True
63
- try:
64
- is_match = fuzzy_check_op(npu_op_name, bench_op_name)
65
- except Exception as err:
66
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
67
- is_match = False
68
- return is_match and struct_match
69
-
70
189
  def read_npy_data(self, dir_path, file_name, load_pt_file=False):
190
+ if not file_name:
191
+ return None
71
192
  data_path = os.path.join(dir_path, file_name)
72
193
  if load_pt_file:
73
194
  import torch
74
195
  from msprobe.pytorch.common.utils import load_pt
75
- data_value = load_pt(data_path).detach()
196
+ data_value = load_pt(data_path, True).detach()
76
197
  if data_value.dtype == torch.bfloat16:
77
198
  data_value = data_value.to(torch.float32)
78
199
  data_value = data_value.numpy()
79
200
  else:
80
201
  data_value = load_npy(data_path)
81
- return data_value
202
+ return data_value
82
203
 
83
- def api_replace(self, npu_op_name, target, para):
84
- for idx, _ in enumerate(npu_op_name):
85
- npu_op_name[idx] = npu_op_name[idx].replace(target, para)
86
- return npu_op_name
87
-
88
- def process_internal_api_mapping(self, npu_op_name, bench_op_name):
204
+ def process_internal_api_mapping(self, npu_op_name):
89
205
  # get api name & class name from op_name
90
206
  # Functional.addcmul.0.forward.input.0
91
- npu_op_name, bench_op_name = npu_op_name.copy(), bench_op_name.copy()
92
- ms_api_name = self.get_api_name(npu_op_name[0].split(Const.SEP))
93
- pt_api_name = self.get_api_name(bench_op_name[0].split(Const.SEP))
207
+ ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
94
208
  class_name = ms_api_name.split(Const.SEP)[0]
95
209
  if class_name == "Mint":
96
- return self.api_replace(npu_op_name, "Mint", "Torch")
210
+ return npu_op_name.replace("Mint", "Torch")
97
211
  elif class_name == "MintFunctional":
98
- return self.api_replace(npu_op_name, "MintFunctional", "Functional")
99
- elif self.ms_to_pt_mapping.get(ms_api_name) == pt_api_name:
100
- return self.api_replace(npu_op_name, ms_api_name, pt_api_name)
212
+ return npu_op_name.replace("MintFunctional", "Functional")
213
+ elif self.ms_to_pt_mapping.get(ms_api_name):
214
+ return npu_op_name.replace(ms_api_name, self.ms_to_pt_mapping.get(ms_api_name))
101
215
  else:
102
- return npu_op_name
103
-
104
- def remove_element(self, op_name, struct, summary, idx):
105
- del op_name[idx]
106
- del struct[idx]
107
- del summary[idx]
216
+ return npu_op_name
108
217
 
109
218
  def get_api_name(self, api_list):
110
- return api_list[0] + Const.SEP + api_list[1]
111
-
112
- def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
113
- """
114
- Transform user mapping API based on new NPU and benchmark dictionaries.
115
- Parameters:
116
- new_npu_dict (dict): New NPU operation dictionary.
117
- new_bench_dict (dict): New benchmark operation dictionary.
118
- Returns:
119
- tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
120
- """
121
- npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
122
- npu_struct_in, bench_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT), new_bench_dict.get(CompareConst.INPUT_STRUCT)
123
- npu_struct_out, bench_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT), new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
124
- npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
125
- npu_in_len, bench_in_len, npu_out_len, bench_out_len = len(npu_struct_in), len(bench_struct_in), len(npu_struct_out), len(bench_struct_out)
126
- ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
127
- ms_api_name = self.get_api_name(ms_api_list)
128
- pt_api_name = self.get_api_name(pt_api_list)
129
- target_dict = {}
130
- for api_dict in self.api_mapping_dict:
131
- if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
132
- ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
133
- ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
134
- if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
135
- logger.warning("The user-defined mapping table is incorrect, make sure that the number of parameters is equal" )
136
- break
137
- ms_out_list = api_dict.get("ms_output", [])
138
- for idx in reversed(range(npu_out_len)):
139
- if idx not in ms_out_list:
140
- del npu_struct_out[idx]
141
- del npu_summary[idx + npu_in_len]
142
- del npu_op_name[idx + npu_in_len]
143
- pt_out_list = api_dict.get("pt_output", [])
144
- for idx in reversed(range(bench_out_len)):
145
- if idx not in pt_out_list:
146
- del bench_struct_out[idx]
147
- del bench_summary[idx + bench_in_len]
148
- del bench_op_name[idx + bench_in_len]
149
- ms_para_list = api_dict.get("ms_args", [])
150
- for idx in reversed(range(npu_in_len)):
151
- if idx not in ms_para_list:
152
- self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
153
- pt_para_list = api_dict.get("pt_args", [])
154
- for idx in reversed(range(bench_in_len)):
155
- if idx not in pt_para_list:
156
- self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
157
- npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
158
- npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
159
- target_dict = api_dict
160
- break
161
- if target_dict:
162
- new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in, CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
163
- new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
164
- return new_npu_dict, new_bench_dict, target_dict
165
-
166
- def para_sequence_update(self, npu_op_name, bench_op_name):
167
- for idx, _ in enumerate(npu_op_name):
168
- bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
169
- if len(bench_op_name_list) != 0:
170
- npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
171
- return npu_op_name
219
+ try:
220
+ api_name = api_list[0] + Const.SEP + api_list[1]
221
+ except IndexError as error:
222
+ logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
223
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
224
+ return api_name
172
225
 
173
- def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
174
- ms_user_args_list = api_dict.get("ms_args", [])
175
- ms_user_output_list = api_dict.get("ms_output", [])
176
- npu_struct_in = npu_dict.get(CompareConst.INPUT_STRUCT)
177
- npu_struct_out = npu_dict.get(CompareConst.OUTPUT_STRUCT)
178
- npu_in_len = len(npu_struct_in)
179
- npu_out_len = len(npu_struct_out)
180
- if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
181
- return del_bench_dict
182
- ms_input_args_list = [i for i in range(npu_in_len)]
183
- input_sub_list =list(set(ms_input_args_list) - set(ms_user_args_list))
184
- ms_output_args_list = [i for i in range(npu_out_len)]
185
- output_sub_list =list(set(ms_output_args_list) - set(ms_user_output_list))
186
- bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
187
- bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
188
- bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
189
- bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
190
- for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
191
- bench_op_name.insert(idx, CompareConst.NAN)
192
- bench_struct_in.insert(idx, CompareConst.NAN)
193
- bench_summary.insert(idx, CompareConst.NAN)
194
- for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
195
- bench_op_name.insert(npu_in_len + idx, CompareConst.NAN)
196
- bench_struct_out.insert(idx, CompareConst.NAN)
197
- bench_summary.insert(npu_in_len + idx, CompareConst.NAN)
198
- del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in, CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
199
- return del_bench_dict
200
-
226
+ def compare_process(self, file_lists, stack_mode, fuzzy_match, dump_mode):
227
+ npu_json_path, bench_json_path, stack_json_path = file_lists
228
+ npu_json_data = load_json(npu_json_path)
229
+ bench_json_data = load_json(bench_json_path)
230
+ stack_json_data = load_json(stack_json_path)
231
+
232
+ npu_df = self.gen_data_df(npu_json_data, stack_json_data, dump_mode)
233
+ bench_df = self.gen_data_df(bench_json_data, stack_json_data, dump_mode)
234
+ if self.cell_mapping:
235
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
236
+ elif self.api_mapping:
237
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
238
+ if isinstance(self.api_mapping, str):
239
+ self.modify_compare_data_with_user_mapping(npu_df, bench_df)
240
+ else:
241
+ npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]
242
+ npu_df[[Const.DTYPE, Const.SHAPE]] = npu_df[[Const.DTYPE, Const.SHAPE]].astype(str)
243
+ bench_df[[Const.DTYPE, Const.SHAPE]] = bench_df[[Const.DTYPE, Const.SHAPE]].astype(str)
244
+ npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
245
+ bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]
246
+ bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
247
+ match_result = pd.merge(npu_df, bench_df, on=[CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE],
248
+ how='outer')
249
+ match_result = match_result[match_result['op_name_x'].notna()].fillna(CompareConst.N_A)
250
+
251
+ def gen_dtype_condition():
252
+ npu_dtype = match_result['dtype_x']
253
+ bench_dtype = match_result['dtype_y']
254
+ if self.cross_frame:
255
+ npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)
256
+ return ((npu_dtype == bench_dtype) |
257
+ ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.FLOAT32)) |
258
+ ((npu_dtype == Const.FLOAT32) & (bench_dtype == Const.FLOAT16)) |
259
+ ((npu_dtype == Const.FLOAT16) & (bench_dtype == Const.BFLOAT16)) |
260
+ ((npu_dtype == Const.BFLOAT16) & (bench_dtype == Const.FLOAT16)) |
261
+ ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_FLOAT32)) |
262
+ ((npu_dtype == Const.TORCH_FLOAT32) & (bench_dtype == Const.TORCH_FLOAT16)) |
263
+ ((npu_dtype == Const.TORCH_FLOAT16) & (bench_dtype == Const.TORCH_BFLOAT16)) |
264
+ ((npu_dtype == Const.TORCH_BFLOAT16) & (bench_dtype == Const.TORCH_FLOAT16)))
201
265
 
266
+ match_result.loc[~gen_dtype_condition(), [i + '_y' for i in bench_df.columns]] = CompareConst.N_A
267
+ return MSComparator.make_result_df(match_result, stack_mode, dump_mode)
268
+
269
+ def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
270
+ def get_api_indices_dict(op_name_df):
271
+ api_indices_dict = defaultdict(list)
272
+ for op_index, name in enumerate(op_name_df[CompareConst.OP_NAME]):
273
+ api = self.get_api_name(name.split(Const.SEP))
274
+ api_indices_dict[api].append(op_index)
275
+ return api_indices_dict
276
+
277
+ ms_api_indices_dict = get_api_indices_dict(npu_df)
278
+ pt_api_indices_dict = get_api_indices_dict(bench_df)
279
+
280
+ def gen_input_compare_key(pattern, term):
281
+ flag = True
282
+ for i, prefix in enumerate(mapping_dict.get(f'ms_{term}')):
283
+ if op_name.split(pattern)[1].startswith(str(prefix)):
284
+ npu_df.loc[index, CompareConst.COMPARE_KEY] = (
285
+ op_name.replace(pattern + str(prefix),
286
+ pattern + str(mapping_dict.get(f'pt_{term}')[i])))
287
+ flag = False
288
+ return flag
289
+
290
+ for mapping_dict in self.api_mapping_dict:
291
+ if (len(mapping_dict.get('ms_args')) != len(mapping_dict.get('pt_args')) or
292
+ len(mapping_dict.get('ms_output')) != len(mapping_dict.get('pt_output'))):
293
+ logger.warning('The user-defined mapping table is incorrect,\
294
+ make sure that the number of parameters is equal')
295
+ continue
296
+ ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
297
+ if ms_api not in ms_api_indices_dict or pt_api not in pt_api_indices_dict:
298
+ continue
299
+ for index in ms_api_indices_dict.get(ms_api):
300
+ op_name = npu_df.loc[index, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
301
+ if CompareConst.INPUT_PATTERN in op_name:
302
+ is_abandoned = gen_input_compare_key(CompareConst.INPUT_PATTERN, 'args')
303
+ elif CompareConst.KWARGS_PATTERN in op_name:
304
+ is_abandoned = gen_input_compare_key(CompareConst.KWARGS_PATTERN, 'args')
305
+ elif CompareConst.OUTPUT_PATTERN in op_name:
306
+ is_abandoned = gen_input_compare_key(CompareConst.OUTPUT_PATTERN, 'output')
307
+ else:
308
+ logger.error(f'Excepted op_name: {op_name}')
309
+ raise CompareException(CompareException.INVALID_DATA_ERROR)
310
+ if is_abandoned:
311
+ npu_df.loc[index, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
312
+
313
+ def gen_data_df(self, data_json, stack_json, dump_mode):
314
+ result = {
315
+ CompareConst.OP_NAME: [],
316
+ Const.DTYPE: [],
317
+ Const.SHAPE: [],
318
+ Const.SUMMARY: [],
319
+ 'stack_info': []
320
+ }
321
+ if dump_mode == Const.ALL:
322
+ result['data_name'] = []
323
+ elif dump_mode == Const.MD5:
324
+ result[Const.MD5] = []
325
+ for data_name in data_json['data']:
326
+ check_op_str_pattern_valid(data_name)
327
+ merge_list = self.gen_merge_list(data_json, data_name, stack_json, dump_mode)
328
+ if not merge_list:
329
+ continue
330
+ for op_name in merge_list[CompareConst.OP_NAME]:
331
+ result[CompareConst.OP_NAME].append(op_name)
332
+ if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
333
+ struct = merge_list[CompareConst.INPUT_STRUCT].pop(0)
334
+ else:
335
+ struct = merge_list[CompareConst.OUTPUT_STRUCT].pop(0)
336
+ result[Const.DTYPE].append(struct[0])
337
+ result[Const.SHAPE].append(struct[1])
338
+ if dump_mode == Const.MD5:
339
+ result[Const.MD5].append(struct[2])
340
+ result[Const.SUMMARY].append(merge_list[Const.SUMMARY].pop(0))
341
+ result['stack_info'].append(merge_list['stack_info'][0])
342
+ if dump_mode == Const.ALL:
343
+ result['data_name'].append(merge_list['data_name'].pop(0))
344
+ return pd.DataFrame(result)
345
+
346
+
347
+ def check_cross_framework(bench_json_path):
348
+ pattern = r'"data_name":\s*"[^"]+\.pt"'
349
+ with FileOpen(bench_json_path, 'r') as file:
350
+ for line in file:
351
+ if re.search(pattern, line):
352
+ return True
353
+ return False
354
+
355
+
202
356
  def ms_compare(input_param, output_path, **kwargs):
203
357
  try:
204
358
  stack_mode = kwargs.get('stack_mode', False)
@@ -206,14 +360,21 @@ def ms_compare(input_param, output_path, **kwargs):
206
360
  fuzzy_match = kwargs.get('fuzzy_match', False)
207
361
  cell_mapping = kwargs.get('cell_mapping', None)
208
362
  api_mapping = kwargs.get('api_mapping', None)
209
- summary_compare, md5_compare = task_dumppath_get(input_param)
210
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
363
+ data_mapping = kwargs.get('data_mapping', None)
364
+ layer_mapping = kwargs.get('layer_mapping', None)
365
+ suffix = kwargs.get('suffix', '')
366
+
367
+ set_dump_path(input_param)
368
+ dump_mode = get_dump_mode(input_param)
369
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
211
370
  create_directory(output_path)
212
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
371
+ check_compare_param(input_param, output_path, dump_mode)
213
372
  except (CompareException, FileCheckException) as error:
214
373
  logger.error('Compare failed. Please check the arguments and do it again!')
215
374
  raise CompareException(error.code) from error
216
- ms_comparator = MSComparator(cell_mapping, api_mapping)
217
- ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
218
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
219
- md5_compare=md5_compare)
375
+ if layer_mapping:
376
+ data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
377
+ is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
378
+ ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
379
+ ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode, suffix=suffix,
380
+ auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, dump_mode=dump_mode)