mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,33 +1,73 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import re
3
- import copy
4
- import sys
5
- from itertools import zip_longest
6
-
7
- from msprobe.core.common.utils import check_compare_param, CompareException, check_configuration_param, \
8
- task_dumppath_get, struct_json_get, add_time_with_yaml
9
- from msprobe.core.common.file_utils import create_directory, load_yaml, load_npy, load_json, save_yaml, FileOpen
10
- from msprobe.core.common.const import Const, CompareConst
11
- from msprobe.core.common.log import logger
18
+ from collections import defaultdict
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+ from msprobe.core.common.const import CompareConst, Const
12
24
  from msprobe.core.common.exceptions import FileCheckException
13
- from msprobe.core.compare.acc_compare import Comparator
14
- from msprobe.core.compare.check import check_struct_match, fuzzy_check_op
15
- from msprobe.mindspore.compare.modify_mapping import modify_mapping_with_stack
16
- from msprobe.mindspore.compare.layer_mapping import get_layer_mapping
25
+ from msprobe.core.common.file_utils import FileOpen, create_directory, load_json, load_npy, load_yaml
26
+ from msprobe.core.common.log import logger
27
+ from msprobe.core.common.utils import CompareException, check_compare_param, check_configuration_param, \
28
+ check_op_str_pattern_valid, get_dump_mode, set_dump_path
29
+ from msprobe.core.compare.acc_compare import Comparator, ModeConfig
30
+ from msprobe.core.compare.check import dtype_mapping
31
+ from msprobe.core.compare.layer_mapping import generate_data_mapping_by_layer_mapping
32
+ from msprobe.core.compare.utils import set_stack_json_path, reorder_op_x_list
17
33
 
18
- class MSComparator(Comparator):
19
- def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None, is_cross_framework=False):
20
- self.frame_name = MSComparator.__name__
34
+
35
+ class MappingConfig:
36
+ def __init__(self, cell_mapping=None, api_mapping=None, data_mapping=None):
21
37
  self.cell_mapping = cell_mapping
22
38
  self.api_mapping = api_mapping
23
39
  self.data_mapping = data_mapping
24
- if data_mapping:
40
+
41
+
42
+ class MSComparator(Comparator):
43
+ """
44
+ 用于mindspore动态图同框架/跨框架精度比对,支持md5/summary/all模式。
45
+ cell_mapping: mindspore在cell级别(L0)dump数据和pytorch的module之间的映射关系;
46
+ api_mapping: mindspore在api级别(L1)dump数据和pytorch的api之间的映射关系;
47
+ data_mapping: mindspore的cell或api的入参/出参和pytorch之间的映射关系;
48
+ is_cross_framework: 是否跨框架。
49
+ """
50
+ def __init__(self, mode_config, mapping_config=None, is_cross_framework=False):
51
+ super().__init__(mode_config)
52
+ self.frame_name = MSComparator.__name__
53
+
54
+ self.stack_mode = mode_config.stack_mode
55
+ self.auto_analyze = mode_config.auto_analyze
56
+ self.fuzzy_match = mode_config.fuzzy_match
57
+ self.dump_mode = mode_config.dump_mode
58
+
59
+ if mapping_config:
60
+ self.cell_mapping = mapping_config.cell_mapping
61
+ self.api_mapping = mapping_config.api_mapping
62
+ self.data_mapping = mapping_config.data_mapping
63
+
64
+ if self.data_mapping:
25
65
  self.cross_frame = is_cross_framework
26
66
  else:
27
- self.cross_frame = cell_mapping is not None or api_mapping is not None
67
+ self.cross_frame = self.cell_mapping is not None or self.api_mapping is not None
28
68
  self.cell_mapping_dict = self.load_mapping_file(self.cell_mapping)
29
69
  self.api_mapping_dict = self.load_mapping_file(self.api_mapping)
30
- if api_mapping is not None:
70
+ if self.api_mapping is not None:
31
71
  self.ms_to_pt_mapping = self.load_internal_api()
32
72
 
33
73
  if isinstance(self.data_mapping, str) or self.data_mapping is None:
@@ -38,9 +78,106 @@ class MSComparator(Comparator):
38
78
  raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
39
79
  f"{type(self.data_mapping)}")
40
80
 
81
+ def calc_accuracy(self, result_df, header):
82
+ condition_no_bench = result_df[CompareConst.BENCH_NAME] == CompareConst.N_A
83
+ result_df[condition_no_bench] = result_df[condition_no_bench].fillna(CompareConst.N_A)
84
+ result_df.loc[condition_no_bench, CompareConst.ERROR_MESSAGE] = CompareConst.NO_BENCH
85
+
86
+ def calc_summary_diff(data_type: str):
87
+ def type_check(val):
88
+ check_series = pd.Series(False, index=val.index)
89
+ val_str = val.astype(str)
90
+ check_series[pd.to_numeric(val_str, errors='coerce').notna() | val_str.str.lower().eq('nan')] = True
91
+ return check_series
92
+
93
+ def get_number(val):
94
+ return pd.to_numeric(val.astype(str), errors='coerce')
95
+
96
+ ms_val = result_df['NPU ' + data_type]
97
+ pt_val = result_df['Bench ' + data_type]
98
+ diff_name = data_type.capitalize() + ' diff'
99
+ rel_err_name = ('norm' if data_type == 'l2norm' else data_type).capitalize() + 'RelativeErr'
100
+ condition_na = ~type_check(ms_val) | ~type_check(pt_val)
101
+ result_df.loc[condition_na, [diff_name, rel_err_name]] = CompareConst.N_A
102
+ result_df.loc[~(condition_no_bench | condition_na), diff_name] = get_number(ms_val) - get_number(pt_val)
103
+ condition_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].isna()
104
+ condition_not_nan_diff = ~condition_no_bench & ~condition_na & result_df[diff_name].notna()
105
+ result_df.loc[condition_nan_diff, [diff_name, rel_err_name]] = CompareConst.NAN
106
+ condition_pt_zero = pt_val == 0
107
+ result_df.loc[condition_not_nan_diff & condition_pt_zero, rel_err_name] = CompareConst.NAN
108
+ condition_ref_err = condition_not_nan_diff & ~condition_pt_zero
109
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, diff_name] /
110
+ pt_val[condition_ref_err] * 100)
111
+ result_df.loc[condition_ref_err, rel_err_name] = (result_df.loc[condition_ref_err, rel_err_name]
112
+ .abs().astype(str) + '%')
113
+ magnitude = get_number(result_df[diff_name]).abs() / (
114
+ pd.Series(np.maximum(get_number(ms_val), get_number(pt_val))).abs() + CompareConst.EPSILON)
115
+ return magnitude > CompareConst.MAGNITUDE
116
+
117
+ if self.dump_mode == Const.MD5:
118
+ condition_md5_equal = result_df[CompareConst.NPU_MD5] == result_df[CompareConst.BENCH_MD5]
119
+ result_df.loc[condition_md5_equal, CompareConst.RESULT] = CompareConst.PASS
120
+ result_df.loc[~condition_md5_equal & ~condition_no_bench, CompareConst.RESULT] = CompareConst.DIFF
121
+ elif self.dump_mode == Const.SUMMARY:
122
+ warning_list = [calc_summary_diff(data_type) for data_type in ['max', 'min', 'mean', 'l2norm']]
123
+ warning_flag = pd.DataFrame(warning_list).all()
124
+ result_df.loc[~condition_no_bench, [CompareConst.RESULT, CompareConst.ERROR_MESSAGE]] = ''
125
+ result_df.loc[warning_flag, CompareConst.RESULT] = CompareConst.WARNING
126
+ result_df.loc[warning_flag, CompareConst.ERROR_MESSAGE] = 'Need double check api accuracy.'
127
+ else:
128
+ fill_cols = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
129
+ CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO,
130
+ CompareConst.ERROR_MESSAGE]
131
+ result_df.loc[~condition_no_bench, fill_cols] = ''
132
+ result_df.loc[~condition_no_bench, CompareConst.ACCURACY] = CompareConst.ACCURACY_CHECK_YES
133
+ return result_df[header]
134
+
135
+ def make_result_df(self, result):
136
+ header = CompareConst.HEAD_OF_COMPARE_MODE[self.dump_mode][:]
137
+
138
+ if self.stack_mode:
139
+ header.append(CompareConst.STACK)
140
+ if self.dump_mode == Const.ALL:
141
+ header.append(CompareConst.DATA_NAME)
142
+ result.rename(columns={'op_name_x': CompareConst.NPU_NAME,
143
+ 'op_name_y': CompareConst.BENCH_NAME,
144
+ 'dtype_x': CompareConst.NPU_DTYPE,
145
+ 'dtype_y': CompareConst.BENCH_DTYPE,
146
+ 'shape_x': CompareConst.NPU_SHAPE,
147
+ 'shape_y': CompareConst.BENCH_SHAPE,
148
+ 'md5_x': CompareConst.NPU_MD5,
149
+ 'md5_y': CompareConst.BENCH_MD5,
150
+ 'data_name_x': CompareConst.DATA_NAME,
151
+ 'stack_info_x': CompareConst.STACK}, inplace=True)
152
+
153
+ npu_summary = [CompareConst.NPU_MAX, CompareConst.NPU_MIN, CompareConst.NPU_MEAN, CompareConst.NPU_NORM]
154
+ bench_summary = [CompareConst.BENCH_MAX, CompareConst.BENCH_MIN, CompareConst.BENCH_MEAN,
155
+ CompareConst.BENCH_NORM]
156
+
157
+ def set_summary(summary):
158
+ if summary == CompareConst.N_A:
159
+ return [CompareConst.N_A] * 4
160
+ summary_list = []
161
+ for i in summary:
162
+ if i is None:
163
+ summary_list.append(CompareConst.N_A)
164
+ elif str(i).lower() == 'nan':
165
+ summary_list.append(CompareConst.NAN)
166
+ else:
167
+ summary_list.append(i)
168
+ return summary_list
169
+
170
+ result[npu_summary] = result['summary_x'].apply(set_summary).tolist()
171
+ result[bench_summary] = result['summary_y'].apply(set_summary).tolist()
172
+ result_df = pd.DataFrame(columns=header)
173
+ for h in header:
174
+ if h in result.columns:
175
+ result_df[h] = result[h]
176
+ return self.calc_accuracy(result_df, header)
177
+
41
178
  def load_internal_api(self):
42
179
  cur_path = os.path.dirname(os.path.realpath(__file__))
43
- yaml_path = os.path.join(cur_path, "ms_to_pt_api.yaml")
180
+ yaml_path = os.path.abspath(os.path.join(cur_path, CompareConst.INTERNAL_API_MAPPING_FILE))
44
181
  return load_yaml(yaml_path)
45
182
 
46
183
  def load_mapping_file(self, mapping_file):
@@ -51,42 +188,23 @@ class MSComparator(Comparator):
51
188
  return mapping_dict
52
189
 
53
190
  def process_cell_mapping(self, npu_op_name):
54
- npu_op_name = [op_name.replace("Cell", "Module", 1) for op_name in npu_op_name]
191
+ if not npu_op_name:
192
+ return CompareConst.N_A
193
+ param_grad_flag = Const.PARAMS_GRAD in npu_op_name.split(Const.SEP)
194
+ if not param_grad_flag and not re.search(Const.REGEX_FORWARD_BACKWARD, npu_op_name):
195
+ return CompareConst.N_A
196
+ npu_op_name = npu_op_name.replace("Cell", "Module", 1)
55
197
  if self.cell_mapping_dict:
56
- for index, op_name in enumerate(npu_op_name):
57
- # get cell name & class name from op_name
58
- # Cell.fc1.Dense.forward.0.input.0
59
- cell_name = op_name.split(Const.SEP, 1)[-1].rsplit(Const.SEP, 4)[0]
60
- if cell_name in self.cell_mapping_dict:
61
- npu_op_name[index] = op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
198
+ # get cell name & class name from op_name
199
+ # Cell.fc1.Dense.forward.0.input.0
200
+ cell_name = re.split(r'\.(?:forward|backward|parameters_grad)\.', npu_op_name.split(Const.SEP, 1)[-1])[0]
201
+ if cell_name in self.cell_mapping_dict:
202
+ npu_op_name = npu_op_name.replace(cell_name, self.cell_mapping_dict[cell_name], 1)
62
203
  return npu_op_name
63
204
 
64
- def check_op(self, npu_dict, bench_dict, fuzzy_match):
65
- npu_dict_new, bench_dict_new = copy.deepcopy(npu_dict), copy.deepcopy(bench_dict)
66
- npu_op_name, bench_op_name = npu_dict_new.get(CompareConst.OP_NAME), bench_dict_new.get(CompareConst.OP_NAME)
67
- if self.cell_mapping is not None:
68
- npu_op_name = self.process_cell_mapping(npu_op_name)
69
- if self.api_mapping is not None:
70
- npu_op_name = self.process_internal_api_mapping(npu_op_name, bench_op_name)
71
- if isinstance(self.api_mapping, str):
72
- npu_dict_new, bench_dict_new, target_dict = self.transform_user_mapping_api(npu_dict_new,
73
- bench_dict_new)
74
- if target_dict:
75
- bench_dict = self.reconstitution_bench_dict(npu_dict, copy.deepcopy(bench_dict_new), target_dict)
76
- npu_op_name = npu_dict_new.get(CompareConst.OP_NAME)
77
- bench_op_name = bench_dict_new.get(CompareConst.OP_NAME)
78
- struct_match = check_struct_match(npu_dict_new, bench_dict_new, cross_frame=self.cross_frame)
79
- if not fuzzy_match:
80
- return npu_op_name == bench_op_name and struct_match
81
- is_match = True
82
- try:
83
- is_match = fuzzy_check_op(npu_op_name, bench_op_name)
84
- except Exception as err:
85
- logger.warning("%s and %s can not fuzzy match." % (npu_op_name, bench_op_name))
86
- is_match = False
87
- return is_match and struct_match
88
-
89
205
  def read_npy_data(self, dir_path, file_name, load_pt_file=False):
206
+ if not file_name:
207
+ return None
90
208
  data_path = os.path.join(dir_path, file_name)
91
209
  if load_pt_file:
92
210
  import torch
@@ -96,35 +214,23 @@ class MSComparator(Comparator):
96
214
  data_value = data_value.to(torch.float32)
97
215
  data_value = data_value.numpy()
98
216
  else:
99
- data_value = load_npy(data_path)
100
- return data_value
217
+ data_value = load_npy(data_path)
218
+ return data_value
101
219
 
102
- def api_replace(self, npu_op_name, target, para):
103
- for idx, _ in enumerate(npu_op_name):
104
- npu_op_name[idx] = npu_op_name[idx].replace(target, para)
105
- return npu_op_name
106
-
107
- def process_internal_api_mapping(self, npu_op_name, bench_op_name):
220
def process_internal_api_mapping(self, npu_op_name):
    """Translate a MindSpore op name into the PyTorch-style name used as compare key.

    The API name is taken from the first two ``Const.SEP``-separated fields
    of ``npu_op_name`` via ``self.get_api_name`` (the original comment gives
    ``Functional.addcmul.0.forward.input.0`` as an example op name).

    Rules, checked in this order:
      * class name ``Mint``           -> renamed to ``Torch``
      * class name ``MintFunctional`` -> renamed to ``Functional``
      * API name with a truthy entry in ``self.ms_to_pt_mapping``
                                      -> renamed to that entry
      * anything else                 -> returned unchanged

    Only the first occurrence is rewritten. The class/API name is by
    construction the leading part of ``npu_op_name``, so replacing every
    occurrence could also corrupt a later field that happens to contain the
    same text. This matches the ``replace(..., 1)`` used by the cell mapping.

    Args:
        npu_op_name (str): full op name read from the NPU dump.

    Returns:
        str: the translated op name, or ``npu_op_name`` itself when no rule applies.

    Raises:
        CompareException: propagated from ``self.get_api_name`` when the
            op name has fewer than two fields.
    """
    # get api name & class name from op_name
    # Functional.addcmul.0.forward.input.0
    ms_api_name = self.get_api_name(npu_op_name.split(Const.SEP))
    class_name = ms_api_name.split(Const.SEP)[0]

    if class_name == "Mint":
        return npu_op_name.replace("Mint", "Torch", 1)
    if class_name == "MintFunctional":
        return npu_op_name.replace("MintFunctional", "Functional", 1)

    # Look the mapping up once; a missing or empty entry means "no translation".
    pt_api_name = self.ms_to_pt_mapping.get(ms_api_name)
    if pt_api_name:
        return npu_op_name.replace(ms_api_name, pt_api_name, 1)
    return npu_op_name
122
-
123
- def remove_element(self, op_name, struct, summary, idx):
124
- del op_name[idx]
125
- del struct[idx]
126
- del summary[idx]
127
-
233
+
128
234
  def get_api_name(self, api_list):
129
235
  try:
130
236
  api_name = api_list[0] + Const.SEP + api_list[1]
@@ -132,184 +238,147 @@ class MSComparator(Comparator):
132
238
  logger.error(f'Failed to retrieve API name, please check if the dump data is reasonable')
133
239
  raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
134
240
  return api_name
135
-
136
- def transform_user_mapping_api(self, new_npu_dict, new_bench_dict):
137
- """
138
- Transform user mapping API based on new NPU and benchmark dictionaries.
139
- Parameters:
140
- new_npu_dict (dict): New NPU operation dictionary.
141
- new_bench_dict (dict): New benchmark operation dictionary.
142
- Returns:
143
- tuple: Updated NPU and benchmark dictionaries, along with the target dictionary.
144
- """
145
- npu_op_name, bench_op_name = new_npu_dict.get(CompareConst.OP_NAME), new_bench_dict.get(CompareConst.OP_NAME)
146
- npu_struct_in = new_npu_dict.get(CompareConst.INPUT_STRUCT)
147
- bench_struct_in = new_bench_dict.get(CompareConst.INPUT_STRUCT)
148
- npu_struct_out = new_npu_dict.get(CompareConst.OUTPUT_STRUCT)
149
- bench_struct_out = new_bench_dict.get(CompareConst.OUTPUT_STRUCT)
150
- npu_summary, bench_summary = new_npu_dict.get(CompareConst.SUMMARY), new_bench_dict.get(CompareConst.SUMMARY)
151
- npu_in_len, bench_in_len = len(npu_struct_in), len(bench_struct_in)
152
- npu_out_len, bench_out_len = len(npu_struct_out), len(bench_struct_out)
153
- ms_api_list, pt_api_list = npu_op_name[0].split(Const.SEP), bench_op_name[0].split(Const.SEP)
154
- ms_api_name = self.get_api_name(ms_api_list)
155
- pt_api_name = self.get_api_name(pt_api_list)
156
- target_dict = {}
157
- for api_dict in self.api_mapping_dict:
158
- if api_dict.get("pt_api") == pt_api_name and api_dict.get("ms_api") == ms_api_name:
159
- ms_user_args_len, pt_user_args_len = len(api_dict.get("ms_args")), len(api_dict.get("pt_args"))
160
- ms_user_output_len, pt_user_output_len = len(api_dict.get("ms_output")), len(api_dict.get("pt_output"))
161
- if ms_user_args_len != pt_user_args_len or ms_user_output_len != pt_user_output_len:
162
- logger.warning("The user-defined mapping table is incorrect,\
163
- make sure that the number of parameters is equal")
164
- break
165
- ms_out_list = api_dict.get("ms_output", [])
166
- for idx in reversed(range(npu_out_len)):
167
- if idx not in ms_out_list:
168
- del npu_struct_out[idx]
169
- if idx + npu_in_len < len(npu_summary) and idx + npu_in_len < len(npu_op_name):
170
- del npu_summary[idx + npu_in_len]
171
- del npu_op_name[idx + npu_in_len]
172
- pt_out_list = api_dict.get("pt_output", [])
173
- for idx in reversed(range(bench_out_len)):
174
- if idx not in pt_out_list:
175
- del bench_struct_out[idx]
176
- if idx + bench_in_len < len(bench_summary) and idx + bench_in_len < len(bench_op_name):
177
- del bench_summary[idx + bench_in_len]
178
- del bench_op_name[idx + bench_in_len]
179
- ms_para_list = api_dict.get("ms_args", [])
180
- for idx in reversed(range(npu_in_len)):
181
- if idx not in ms_para_list:
182
- self.remove_element(npu_op_name, npu_struct_in, npu_summary, idx)
183
- pt_para_list = api_dict.get("pt_args", [])
184
- for idx in reversed(range(bench_in_len)):
185
- if idx not in pt_para_list:
186
- self.remove_element(bench_op_name, bench_struct_in, bench_summary, idx)
187
- npu_op_name = self.api_replace(npu_op_name, ms_api_name, pt_api_name)
188
- npu_op_name = self.para_sequence_update(npu_op_name, bench_op_name)
189
- target_dict = api_dict
190
- break
191
- if target_dict:
192
- new_npu_dict.update({CompareConst.OP_NAME: npu_op_name, CompareConst.INPUT_STRUCT: npu_struct_in,
193
- CompareConst.OUTPUT_STRUCT: npu_struct_out, CompareConst.SUMMARY: npu_summary})
194
- new_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
195
- CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
196
- return new_npu_dict, new_bench_dict, target_dict
197
-
198
- def para_sequence_update(self, npu_op_name, bench_op_name):
199
- for idx, _ in enumerate(npu_op_name):
200
- bench_op_name_list = bench_op_name[idx].rsplit(Const.SEP, 1)
201
- if len(bench_op_name_list) != 0:
202
- npu_op_name[idx] = npu_op_name[idx][:-1] + bench_op_name_list[-1]
203
- return npu_op_name
204
241
 
205
- def reconstitution_bench_dict(self, npu_dict, del_bench_dict, api_dict):
206
- ms_user_args_list = api_dict.get("ms_args", [])
207
- ms_user_output_list = api_dict.get("ms_output", [])
208
- npu_struct_in = npu_dict.get(CompareConst.INPUT_STRUCT)
209
- npu_struct_out = npu_dict.get(CompareConst.OUTPUT_STRUCT)
210
- npu_in_len = len(npu_struct_in)
211
- npu_out_len = len(npu_struct_out)
212
- if npu_in_len == len(ms_user_args_list) and npu_out_len == len(ms_user_output_list):
213
- return del_bench_dict
214
- ms_input_args_list = [i for i in range(npu_in_len)]
215
- input_sub_list = list(set(ms_input_args_list) - set(ms_user_args_list))
216
- ms_output_args_list = [i for i in range(npu_out_len)]
217
- output_sub_list = list(set(ms_output_args_list) - set(ms_user_output_list))
218
- bench_op_name = del_bench_dict.get(CompareConst.OP_NAME, [])
219
- bench_struct_in = del_bench_dict.get(CompareConst.INPUT_STRUCT, [])
220
- bench_struct_out = del_bench_dict.get(CompareConst.OUTPUT_STRUCT, [])
221
- bench_summary = del_bench_dict.get(CompareConst.SUMMARY, [])
222
- for idx in input_sub_list: # Fill in the blank value field in the pt dictionary
223
- bench_op_name.insert(idx, CompareConst.N_A)
224
- bench_struct_in.insert(idx, CompareConst.N_A)
225
- bench_summary.insert(idx, CompareConst.N_A)
226
- for idx in output_sub_list: # Fill in the blank value field in the pt dictionary
227
- bench_op_name.insert(npu_in_len + idx, CompareConst.N_A)
228
- bench_struct_out.insert(idx, CompareConst.N_A)
229
- bench_summary.insert(npu_in_len + idx, CompareConst.N_A)
230
- del_bench_dict.update({CompareConst.OP_NAME: bench_op_name, CompareConst.INPUT_STRUCT: bench_struct_in,
231
- CompareConst.OUTPUT_STRUCT: bench_struct_out, CompareConst.SUMMARY: bench_summary})
232
- return del_bench_dict
233
-
234
-
235
- def sort_by_execution_sequence(npu_data, bench_data, mapping_list, flag):
236
- def generate_execution_sequence(data):
237
- sequence_map = {}
238
- for index, item in enumerate(data.keys()):
239
- if flag in item:
240
- item_split = item.split(Const.SEP)
241
- item_name = Const.SEP.join(item_split[0:-2])
242
- item_index = item_split[-1]
243
- if item_index == 'forward' or item_index == 'backward':
244
- item_index = item_split[-2]
245
- item_key = f"{item_name}.{item_index}"
246
- sequence_map[item_key] = index
247
- return sequence_map
248
-
249
- npu_map = generate_execution_sequence(npu_data)
250
- bench_map = generate_execution_sequence(bench_data)
251
-
252
- def sort_by_map(item):
253
- first_key = npu_map.get(item[0], sys.maxsize)
254
- second_key = bench_map.get(item[1], sys.maxsize)
255
- return first_key, second_key
256
-
257
- return sorted(mapping_list, key=sort_by_map)
258
-
259
-
260
- def generate_kernel_data(map_value, data, flag):
261
- if not map_value:
262
- return [], []
263
- inputs_name = []
264
- outputs_name = []
265
- map_split = map_value.split(Const.SEP)
266
- map_name = Const.SEP.join(map_split[0:-1])
267
- map_index = map_split[-1]
268
- for key, value in data.items():
269
- if key.find(flag) != -1 and key.find(map_name) != -1:
270
- if key.split(Const.SEP)[-1] != map_index and key.split(Const.SEP)[-2] != map_index :
242
def compare_process(self, file_lists):
    """Match NPU dump entries against bench dump entries and build the result table.

    Steps:
      1. Load the NPU and bench dump json files (and the stack json when
         ``self.stack_mode`` is set) and turn each dump into a DataFrame
         with ``self.gen_data_df``.
      2. Give every NPU row a compare key: the cell-mapped name when
         ``self.cell_mapping`` is set, otherwise the API-mapped name when
         ``self.api_mapping`` is set (further adjusted by the user mapping
         table when ``self.api_mapping`` is a string), otherwise the plain
         op name. Bench rows always use their own op name as key.
      3. Outer-merge both frames on (compare key, compare shape), keep only
         rows that have an NPU op name and fill the gaps with
         ``CompareConst.N_A``.
      4. Where the NPU and bench dtypes are neither equal nor one of the
         tolerated float16/float32/bfloat16 pairings, overwrite the
         bench-side columns with ``CompareConst.N_A``.

    Args:
        file_lists: iterable unpacking into
            ``(npu_json_path, bench_json_path, stack_json_path)``.

    Returns:
        Whatever ``self.make_result_df`` builds from the matched frame.
    """
    npu_json_path, bench_json_path, stack_json_path = file_lists
    npu_json_data = load_json(npu_json_path)
    bench_json_data = load_json(bench_json_path)
    if self.stack_mode:
        stack_json_data = load_json(stack_json_path)
    else:
        stack_json_data = None

    npu_df = self.gen_data_df(npu_json_data, stack_json_data)
    bench_df = self.gen_data_df(bench_json_data, stack_json_data)

    # Compare key of the NPU side: cell mapping wins over api mapping.
    if self.cell_mapping:
        npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_cell_mapping)
    elif self.api_mapping:
        npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME].apply(self.process_internal_api_mapping)
        if isinstance(self.api_mapping, str):
            self.modify_compare_data_with_user_mapping(npu_df, bench_df)
    else:
        npu_df[CompareConst.COMPARE_KEY] = npu_df[CompareConst.OP_NAME]

    # dtype and shape are stringified before they take part in merging and comparison.
    struct_columns = [Const.DTYPE, Const.SHAPE]
    npu_df[struct_columns] = npu_df[struct_columns].astype(str)
    bench_df[struct_columns] = bench_df[struct_columns].astype(str)
    npu_df[CompareConst.COMPARE_SHAPE] = npu_df[Const.SHAPE]
    bench_df[CompareConst.COMPARE_KEY] = bench_df[CompareConst.OP_NAME]
    bench_df[CompareConst.COMPARE_SHAPE] = bench_df[Const.SHAPE]

    merge_keys = [CompareConst.COMPARE_KEY, CompareConst.COMPARE_SHAPE]
    matched = pd.merge(npu_df, bench_df, on=merge_keys, how='outer')
    # After the merge, '_x' columns come from npu_df and '_y' columns from bench_df.
    # Rows without an NPU op name exist only on the bench side and are dropped.
    npu_row_present = matched['op_name_x'].notna()
    matched = matched[npu_row_present].fillna(CompareConst.N_A)

    npu_dtype = matched['dtype_x']
    bench_dtype = matched['dtype_y']
    if self.cross_frame:
        # Translate NPU dtypes through dtype_mapping; unmapped values stay as they are.
        npu_dtype = npu_dtype.map(dtype_mapping).fillna(npu_dtype)

    # (npu dtype, bench dtype) pairings accepted in addition to exact equality.
    tolerated_pairs = (
        (Const.FLOAT16, Const.FLOAT32),
        (Const.FLOAT32, Const.FLOAT16),
        (Const.FLOAT16, Const.BFLOAT16),
        (Const.BFLOAT16, Const.FLOAT16),
        (Const.TORCH_FLOAT16, Const.TORCH_FLOAT32),
        (Const.TORCH_FLOAT32, Const.TORCH_FLOAT16),
        (Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16),
        (Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16),
    )
    dtype_ok = npu_dtype == bench_dtype
    for npu_kind, bench_kind in tolerated_pairs:
        dtype_ok = dtype_ok | ((npu_dtype == npu_kind) & (bench_dtype == bench_kind))

    # Rows with incompatible dtypes keep their NPU data but lose the bench match.
    bench_side_columns = [name + '_y' for name in bench_df.columns]
    matched.loc[~dtype_ok, bench_side_columns] = CompareConst.N_A
    return self.make_result_df(matched)
284
+
285
def modify_compare_data_with_user_mapping(self, npu_df, bench_df):
    """Rewrite NPU compare keys in place according to the user-defined API mapping.

    For every entry of ``self.api_mapping_dict`` whose ``ms_api`` occurs in
    ``npu_df`` and whose ``pt_api`` occurs in ``bench_df``, each NPU row of
    that API receives a new ``CompareConst.COMPARE_KEY``:

      * the op name with ``ms_api`` replaced by ``pt_api`` (first occurrence),
      * and the identifier following the input / kwargs / output /
        parameters / parameters_grad pattern replaced by the PyTorch
        counterpart listed at the same position of the mapping entry.

    Rows whose identifier is not listed in the mapping entry get the key
    ``<renamed op name> + 'abandoned'`` (presumably so that they cannot
    match any bench row - verify against the merge in the caller).

    Mapping entries whose ``ms_*`` / ``pt_*`` lists differ in length are
    skipped with a warning. A list that is missing or empty counts as
    having no elements.

    Args:
        npu_df: NPU-side DataFrame; its compare-key column is modified in place.
        bench_df: bench-side DataFrame; only read, to learn which APIs it contains.

    Returns:
        None.

    Raises:
        CompareException: ``INVALID_DATA_ERROR`` when a renamed op name
            contains none of the known patterns.
    """
    list_pairs = (
        ('ms_args', 'pt_args'),
        ('ms_output', 'pt_output'),
        ('ms_parameters', 'pt_parameters'),
        ('ms_parameters_grad', 'pt_parameters_grad'),
    )
    # Checked in this order; the first pattern found in the op name decides
    # which mapping list is consulted.
    pattern_terms = (
        (CompareConst.INPUT_PATTERN, 'args'),
        (CompareConst.KWARGS_PATTERN, 'args'),
        (CompareConst.OUTPUT_PATTERN, 'output'),
        (CompareConst.PARAMS_PATTERN, 'parameters'),
        (CompareConst.PARAMS_GRAD_PATTERN, 'parameters_grad'),
    )

    def collect_api_rows(frame):
        # Group the row labels of `frame` by API name. Real index labels are
        # used (not positions) because they are fed to `.loc` below.
        rows_by_api = defaultdict(list)
        for row_label, name in frame[CompareConst.OP_NAME].items():
            rows_by_api[self.get_api_name(name.split(Const.SEP))].append(row_label)
        return rows_by_api

    def rewrite_compare_key(mapping_dict, row_label, op_name, pattern, term):
        # Returns True when the identifier after `pattern` is listed in the
        # mapping entry and the compare key was rewritten.
        ms_terms = mapping_dict.get(f'ms_{term}') or []
        pt_terms = mapping_dict.get(f'pt_{term}') or []
        identifier = op_name.split(pattern)[1]
        matched = False
        for ms_term, pt_term in zip(ms_terms, pt_terms):
            ms_text = str(ms_term)
            # Match a whole name component: "1" must not capture "10", but
            # still matches "1" itself and nested names such as "1.0".
            if identifier == ms_text or identifier.startswith(ms_text + Const.SEP):
                npu_df.loc[row_label, CompareConst.COMPARE_KEY] = op_name.replace(
                    pattern + ms_text, pattern + str(pt_term), 1)
                matched = True
        return matched

    ms_rows_by_api = collect_api_rows(npu_df)
    pt_rows_by_api = collect_api_rows(bench_df)

    for mapping_dict in self.api_mapping_dict:
        if any(len(mapping_dict.get(ms_key) or []) != len(mapping_dict.get(pt_key) or [])
               for ms_key, pt_key in list_pairs):
            logger.warning('The user-defined mapping table is incorrect, '
                           'make sure that the number of parameters is equal')
            continue

        ms_api, pt_api = mapping_dict.get('ms_api'), mapping_dict.get('pt_api')
        if ms_api not in ms_rows_by_api or pt_api not in pt_rows_by_api:
            continue

        for row_label in ms_rows_by_api[ms_api]:
            op_name = npu_df.loc[row_label, CompareConst.OP_NAME].replace(ms_api, pt_api, 1)
            for pattern, term in pattern_terms:
                if pattern in op_name:
                    matched = rewrite_compare_key(mapping_dict, row_label, op_name, pattern, term)
                    break
            else:
                logger.error(f'Excepted op_name: {op_name}')
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            if not matched:
                npu_df.loc[row_label, CompareConst.COMPARE_KEY] = op_name + 'abandoned'
338
+
339
def gen_data_df(self, data_json, stack_json_data):
    """Flatten one dump json into a DataFrame with one row per op name.

    Every key of ``data_json['data']`` is validated with
    ``check_op_str_pattern_valid`` and expanded with ``self.gen_merge_list``;
    keys for which that returns a falsy value are skipped. The op names,
    summaries and data names of the expansion are reordered by
    ``reorder_op_x_list`` and then emitted row by row.

    Columns: op name, dtype, shape, summary and ``'stack_info'``; plus
    ``'data_name'`` when ``self.dump_mode == Const.ALL`` or the md5 column
    when ``self.dump_mode == Const.MD5``.

    Args:
        data_json: loaded dump json; must contain a ``'data'`` mapping.
        stack_json_data: loaded stack json or None, forwarded to
            ``self.gen_merge_list``.

    Returns:
        pandas.DataFrame: the flattened table.
    """
    columns = {
        CompareConst.OP_NAME: [],
        Const.DTYPE: [],
        Const.SHAPE: [],
        Const.SUMMARY: [],
        'stack_info': []
    }
    if self.dump_mode == Const.ALL:
        columns['data_name'] = []
    elif self.dump_mode == Const.MD5:
        columns[Const.MD5] = []

    for api_name in data_json['data']:
        check_op_str_pattern_valid(api_name)
        merged = self.gen_merge_list(data_json, api_name, stack_json_data)
        if not merged:
            continue

        ordered_names, ordered_summaries, ordered_data_names = reorder_op_x_list(
            merged.get(CompareConst.OP_NAME),
            merged.get(Const.SUMMARY),
            merged.get('data_name'))

        for op_name in ordered_names:
            columns[CompareConst.OP_NAME].append(op_name)

            # Pick the struct list that belongs to this kind of op name;
            # anything that is not input/kwargs/output/parameters falls
            # back to the parameters-grad structs.
            if (CompareConst.INPUT_PATTERN in op_name) or (CompareConst.KWARGS_PATTERN in op_name):
                struct_key = CompareConst.INPUT_STRUCT
            elif CompareConst.OUTPUT_PATTERN in op_name:
                struct_key = CompareConst.OUTPUT_STRUCT
            elif CompareConst.PARAMS_PATTERN in op_name:
                struct_key = CompareConst.PARAMS_STRUCT
            else:
                struct_key = CompareConst.PARAMS_GRAD_STRUCT
            # Structs are consumed front to back, one per op name of that kind.
            struct = merged[struct_key].pop(0)

            columns[Const.DTYPE].append(struct[0])
            columns[Const.SHAPE].append(struct[1])
            if self.dump_mode == Const.MD5:
                columns[Const.MD5].append(struct[2])
            columns[Const.SUMMARY].append(ordered_summaries.pop(0))

            # All rows of one api share the first stack entry of the expansion.
            if self.stack_mode:
                stack_info = merged['stack_info'][0]
            else:
                stack_info = None
            columns['stack_info'].append(stack_info)

            if self.dump_mode == Const.ALL:
                columns['data_name'].append(ordered_data_names.pop(0))

    return pd.DataFrame(columns)
313
382
 
314
383
 
315
384
  def check_cross_framework(bench_json_path):
@@ -323,35 +392,31 @@ def check_cross_framework(bench_json_path):
323
392
 
324
393
  def ms_compare(input_param, output_path, **kwargs):
325
394
  try:
326
- stack_mode = kwargs.get('stack_mode', False)
327
395
  auto_analyze = kwargs.get('auto_analyze', True)
328
396
  fuzzy_match = kwargs.get('fuzzy_match', False)
329
397
  cell_mapping = kwargs.get('cell_mapping', None)
330
398
  api_mapping = kwargs.get('api_mapping', None)
331
399
  data_mapping = kwargs.get('data_mapping', None)
332
400
  layer_mapping = kwargs.get('layer_mapping', None)
401
+ suffix = kwargs.get('suffix', '')
333
402
 
334
- summary_compare, md5_compare = task_dumppath_get(input_param)
403
+ set_dump_path(input_param)
404
+ dump_mode = get_dump_mode(input_param)
405
+ if 'stack_json_path' in input_param:
406
+ stack_mode = kwargs.get('stack_mode', False)
407
+ else:
408
+ stack_mode = set_stack_json_path(input_param) # set stack_mode and set "stack_json_path" in input_param
335
409
  check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
336
410
  create_directory(output_path)
337
- check_compare_param(input_param, output_path, summary_compare, md5_compare)
411
+ check_compare_param(input_param, output_path, dump_mode, stack_mode)
338
412
  except (CompareException, FileCheckException) as error:
339
413
  logger.error('Compare failed. Please check the arguments and do it again!')
340
414
  raise CompareException(error.code) from error
341
415
  if layer_mapping:
342
- pt_stack, pt_construct = struct_json_get(input_param, Const.PT_FRAMEWORK)
343
- ms_stack, ms_construct = struct_json_get(input_param, Const.MS_FRAMEWORK)
344
- mapping = load_yaml(layer_mapping)
345
- ms_mapping_result = modify_mapping_with_stack(ms_stack, ms_construct)
346
- pt_mapping_result = modify_mapping_with_stack(pt_stack, pt_construct)
347
- layer_mapping = get_layer_mapping(ms_mapping_result, pt_mapping_result, mapping)
348
- data_mapping = generate_file_mapping(input_param.get("npu_json_path"), input_param.get("bench_json_path"), layer_mapping)
349
-
350
- data_mapping_name = add_time_with_yaml(f"data_mapping")
351
- data_mapping_path = os.path.join(os.path.realpath(output_path), f"{data_mapping_name}")
352
- save_yaml(data_mapping_path, data_mapping)
353
- is_cross_framework = check_cross_framework(input_param.get("bench_json_path"))
354
- ms_comparator = MSComparator(cell_mapping, api_mapping, data_mapping, is_cross_framework)
355
- ms_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
356
- auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
357
- md5_compare=md5_compare)
416
+ data_mapping = generate_data_mapping_by_layer_mapping(input_param, layer_mapping, output_path)
417
+
418
+ mode_config = ModeConfig(stack_mode, auto_analyze, fuzzy_match, dump_mode)
419
+ mapping_config = MappingConfig(cell_mapping, api_mapping, data_mapping)
420
+ is_cross_framework = check_cross_framework(input_param.get('bench_json_path'))
421
+ ms_comparator = MSComparator(mode_config, mapping_config, is_cross_framework)
422
+ ms_comparator.compare_core(input_param, output_path, suffix=suffix)