mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,246 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import re
18
+ from copy import deepcopy
19
+ from dataclasses import dataclass
20
+ from typing import ClassVar, Dict, List, Optional, Tuple
21
+
22
+ import yaml
23
+ from msprobe.core.common.const import Const
24
+ from msprobe.core.common.file_utils import save_yaml
25
+ from msprobe.core.common.log import logger
26
+ from msprobe.core.common.utils import CompareException, add_time_with_yaml
27
+ from msprobe.core.compare.layer_mapping.postprocess_pass import postprocess_pass
28
+
29
+
30
@dataclass
class DumpDataItem:
    """Scope information parsed for a single dump-data entry.

    ``set`` fills the item from the entry name plus its construct and stack
    records: ``set_name`` splits the name into fields, ``set_layer_scope``
    derives the layer scope, ``set_stack_scope`` derives the framework/user
    call-stack scopes, and ``set_full_scope`` joins them into ``full_scope``.
    """

    framework: str
    data_name: Optional[str] = None
    api_type: Optional[str] = None
    api_name: Optional[str] = None
    type_name: Optional[str] = None
    full_scope: str = ""
    layer_scope: str = ""
    stack_scope: str = ""
    frame_stack_scope: str = ""
    user_stack_scope: str = ""
    construct_scope: str = ""
    scope_direction: Optional[str] = None
    # Stored exactly as split out of the construct name, i.e. a string.
    scope_id: Optional[str] = None
    state: str = ""

    # Class-level constants are declared with ClassVar so that the dataclass
    # machinery does not treat them as instance fields.
    layernames: ClassVar[set] = {Const.CELL, Const.MODULE}
    # framework -> (start pattern, end pattern) handed to find_regard_scope.
    framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
        Const.MS_FRAMEWORK: ("Template", "construct"),
        Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
    }

    @staticmethod
    def check_stack_valid(stack_info):
        """Ensure ``stack_info`` is ``None`` or a list made only of strings.

        Raises:
            CompareException: with ``INVALID_DATA_ERROR`` when ``stack_info``
                is not a list or contains a non-string element.
        """
        if stack_info is None:
            return
        if not isinstance(stack_info, list) or not all(isinstance(stack, str) for stack in stack_info):
            logger.error(f"stack is invalid, it should be a list[str], but got {stack_info}")
            raise CompareException(CompareException.INVALID_DATA_ERROR)

    def set(self, data_name: str, construct_info: Optional[str], stack_info: Optional[List[str]]) -> None:
        """Populate every scope field of this item.

        Args:
            data_name: name of the dump entry, fields separated by ``Const.SEP``.
            construct_info: construct record of the entry, or ``None``/empty.
            stack_info: stack lines of the entry, or ``None``/empty.
        """
        self.set_name(data_name)
        self.set_layer_scope(construct_info)
        self.set_stack_scope(stack_info)
        self.set_full_scope()

    def set_name(self, data_name: str) -> None:
        """Split ``data_name`` and fill ``api_type``, ``api_name``, ``type_name`` and ``state``.

        Raises:
            CompareException: with ``INVALID_DATA_ERROR`` when the name has
                fewer than ``abs(Const.LAYER_NAME_INDEX)`` fields.
        """
        self.data_name = data_name
        data_name_list = data_name.split(Const.SEP)
        if not data_name_list or len(data_name_list) < abs(Const.LAYER_NAME_INDEX):
            logger.error(
                f"The dump data does not comply with the format specification and "
                f"must contain no less than four fields. "
                f"The current data is {data_name}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)

        # Parameter-gradient entries are recognised by their last field and
        # use their own index constants.
        if data_name_list[Const.LAST_INDEX] == Const.PARAMS_GRAD:
            self.api_type = Const.PARAMS_GRAD
            self.api_name = data_name_list[Const.PARAMS_GRAD_NAME_INDEX]
            self.type_name = data_name_list[Const.PARAMS_GRAD_TYPE_NAME_INDEX]
            self.state = Const.PARAMS_GRAD
            return

        self.api_type = data_name_list[Const.API_TYPE_INDEX]
        self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
        if self.api_type in self.layernames:
            self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
            self.state = data_name_list[Const.SCOPE_DIRECTION_INDEX]
        else:
            self.api_name = self.type_name
            self.state = data_name_list[Const.LAST_INDEX]

    def set_layer_scope(self, construct_info: Optional[str]) -> None:
        """Derive ``layer_scope`` and, when available, ``scope_id``/``scope_direction``.

        Raises:
            CompareException: with ``INVALID_DATA_ERROR`` when a non-empty
                ``construct_info`` has fewer than
                ``abs(Const.LAYER_NAME_INDEX)`` fields.
        """
        self.construct_scope = construct_info
        if self.api_type in self.layernames:
            # remove api name
            data_list = self.data_name.split(Const.SEP)
            data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
        elif self.api_type == Const.PARAMS_GRAD:
            data_list = self.data_name.split(Const.SEP)
        elif construct_info:
            data_list = construct_info.split(Const.SEP)
        else:
            data_list = []

        if data_list:
            self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
        else:
            self.layer_scope = Const.TOP_LAYER

        if construct_info:
            construct_list = construct_info.split(Const.SEP)
            if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
                logger.error(
                    f"The construct data does not comply with the format specification and "
                    f"must contain no less than four fields. "
                    f"The current data is {construct_info}"
                )
                raise CompareException(CompareException.INVALID_DATA_ERROR)
            self.scope_id = construct_list[Const.SCOPE_ID_INDEX]
            self.scope_direction = construct_list[Const.SCOPE_DIRECTION_INDEX]

    def set_stack_scope(self, stack_info: Optional[List[str]]) -> None:
        """Derive ``frame_stack_scope`` and ``user_stack_scope`` from the stack lines.

        Nothing is done for Cell/Module entries, for api types listed in
        ``Const.DATA_TYPE_SKIP_LIST``, or when ``stack_info`` is empty.

        Raises:
            CompareException: with ``INVALID_DATA_ERROR`` when the framework
                has no entry in ``framework2stack_sign`` or ``stack_info`` is
                not a list of strings.
        """
        # Cell/Module has no stack info
        if self.api_type in self.layernames:
            return

        if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
            return

        stack_sign = self.framework2stack_sign.get(self.framework)
        if stack_sign is None:
            # Fail with the module's own exception instead of an opaque
            # "cannot unpack NoneType" TypeError.
            logger.error(
                f"Unsupported framework {self.framework}, "
                f"it should be one of {list(self.framework2stack_sign)}"
            )
            raise CompareException(CompareException.INVALID_DATA_ERROR)
        start_sign, end_sign = stack_sign

        self.check_stack_valid(stack_info)
        start_pos, end_pos = find_regard_scope(stack_info, start_sign, end_sign)
        # Keep only the lines strictly between the start and end markers.
        regard_scope = stack_info[start_pos + 1:end_pos]
        frame_func_stack_list, user_func_stack_list = find_stack_func_list(regard_scope)
        self.frame_stack_scope = Const.SEP.join(frame_func_stack_list)
        self.user_stack_scope = Const.SEP.join(user_func_stack_list)

    def set_full_scope(self, use_user_func_scope=False, use_frame_func_scope=True) -> None:
        """Join layer scope, the selected non-empty stack scopes and ``api_name`` into ``full_scope``."""
        scope_list = [self.layer_scope]
        if use_user_func_scope and self.user_stack_scope:
            scope_list.append(self.user_stack_scope)
        if use_frame_func_scope and self.frame_stack_scope:
            scope_list.append(self.frame_stack_scope)
        scope_list.append(self.api_name)
        self.full_scope = Const.SEP.join(scope_list)
152
+
153
+
154
def find_regard_scope(lines, start_sign, end_sign):
    """Locate the window of ``lines`` bounded by the start and end patterns.

    The start index is the last line matching ``start_sign`` that occurs
    before the end is found (``-1`` when none matches). The end index is the
    first later line that matches ``end_sign`` and does not match
    ``start_sign`` (``len(lines)`` when none is found).

    Returns:
        Tuple ``(start_pos, end_pos)``.
    """
    begin = -1
    finish = len(lines)
    for position, text in enumerate(lines):
        if re.search(start_sign, text):
            # A later start marker overrides an earlier one.
            begin = position
            continue
        if begin >= 0 and re.search(end_sign, text):
            finish = position
            break
    return begin, finish
165
+
166
+
167
def find_stack_func_list(lines, record_user=True):
    """Split stack lines into the framework-entrance function and user functions.

    Lines are scanned in order. Until the first line that is not a framework
    line is met, every line whose file field contains an entry of
    ``Const.FRAME_FILE_LIST`` replaces the remembered framework line, so the
    last of those leading framework lines is kept. All remaining lines are
    collected as user lines when ``record_user`` is true.

    Args:
        lines: stack lines, fields separated by ``Const.COMMA``.
        record_user: whether non-framework lines are collected.

    Returns:
        Tuple ``(frame_func, user_func)`` of function-name lists as produced
        by ``get_stack_in_lines``.
    """
    user_stack = []
    frame_stack = None
    no_entrance = True
    for line in lines:
        ele_list = line.split(Const.COMMA)
        file_ele = ele_list[Const.STACK_FILE_INDEX]
        # Framework line seen while no framework entrance has been fixed yet.
        if any(ii in file_ele for ii in Const.FRAME_FILE_LIST) and no_entrance:
            frame_stack = line  # remember the latest leading framework line
        else:
            if record_user:
                user_stack.append(line)
            no_entrance = False

    # Filter the collected lines down to function names.
    # frame_stack may still be None here; get_stack_in_lines skips empty entries.
    frame_func = get_stack_in_lines([frame_stack])
    user_func = get_stack_in_lines(user_stack)
    return (frame_func, user_func)
191
+
192
+
193
def get_stack_in_lines(simplified: List[str]):
    """Extract function names from stack lines, returned in reverse order.

    Empty entries are ignored, as are lines whose file field contains an
    entry of ``Const.FILE_SKIP_LIST`` or whose function field contains an
    entry of ``Const.FUNC_SKIP_LIST``.

    Args:
        simplified: stack lines, fields separated by ``Const.COMMA``.

    Returns:
        List of function names, last input line first.
    """
    if not simplified:
        return []

    func_names = []
    for entry in simplified:
        if not entry:
            continue

        fields = entry.split(Const.COMMA)
        file_field = fields[Const.STACK_FILE_INDEX]
        if any(skipped in file_field for skipped in Const.FILE_SKIP_LIST):
            continue

        func_field = fields[Const.STACK_FUNC_INDEX]
        if any(skipped in func_field for skipped in Const.FUNC_SKIP_LIST):
            continue

        func_names.append(func_field.split()[Const.STACK_FUNC_ELE_INDEX])

    func_names.reverse()
    return func_names
216
+
217
+
218
def dumpdata_representer(dumper, data):
    """Represent ``data`` as a mapping of its attributes, leaving out ``data_name``.

    The attribute dict is deep-copied first, so ``data`` itself is not modified.
    """
    fields = deepcopy(data.__dict__)
    del fields["data_name"]
    return dumper.represent_dict(fields)
222
+
223
+
224
def get_dump_data_items(dump, stack, construct, framework, output_path=None):
    """Build a ``DumpDataItem`` for every entry under ``dump["data"]``.

    Args:
        dump: mapping whose ``"data"`` key holds the dump entries by name.
        stack: mapping from entry name to its stack record.
        construct: mapping from entry name to its construct record.
        framework: framework identifier given to each ``DumpDataItem``.
        output_path: when set, the name-to-item mapping is also saved as a
            YAML file inside this directory.

    Returns:
        The items in iteration order of ``dump["data"]`` after
        ``postprocess_pass`` has run, or an empty list when ``stack`` or
        ``construct`` is empty.
    """
    if not (stack and construct):
        return []

    items_by_name = {}
    collected = []
    for entry_name in dump.get("data", {}):
        stack_record = stack.get(entry_name, None)
        construct_record = construct.get(entry_name, None)
        item = DumpDataItem(framework)
        item.set(entry_name, construct_record, stack_record)
        items_by_name[entry_name] = item
        collected.append(item)

    postprocess_pass(collected, items_by_name)

    if output_path:
        # Register the representer so DumpDataItem objects can be emitted as YAML.
        yaml.add_representer(DumpDataItem, dumpdata_representer)
        target_dir = os.path.realpath(output_path)
        target_file = os.path.join(target_dir, add_time_with_yaml(f"{framework}_data"))
        save_yaml(target_file, items_by_name)
    return collected
@@ -0,0 +1,249 @@
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from collections import defaultdict
18
+
19
+ from msprobe.core.common.const import CompareConst, Const
20
+ from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
21
+ from msprobe.core.common.utils import (add_time_with_yaml,
22
+ detect_framework_by_dump_json,
23
+ get_stack_construct_by_dump_json_path)
24
+ from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
25
+ from msprobe.core.compare.utils import read_op, reorder_op_name_list
26
+
27
+
28
+
29
class LayerTrie:
    """Tree of layer scopes; each node stores the data items recorded at its scope, grouped by state."""

    def __init__(self, type_name, framework=None):
        """Create a node labelled ``type_name`` with no children and no data items."""
        self.type_name = type_name
        # state -> list of data items, in insertion order
        self.data_items = defaultdict(list)
        # child scope segment -> LayerTrie
        self.children = {}
        self.framework = framework

    def __repr__(self):
        data_nums = [{k: len(v)} for k, v in self.data_items.items()]
        return f"Layer(type_name={self.type_name}, data_number={data_nums})"

    def get(self, name):
        """Return the direct child called ``name``, or ``None`` when there is none."""
        return self.children.get(name)

    def insert(self, data_item):
        """Store ``data_item`` at the node addressed by its ``full_scope``, creating nodes on the way.

        The scope is split on ``Const.SEP`` and the segments from
        ``Const.RIGHT_MOVE_INDEX`` onward are walked from this node. The final
        node takes over ``data_item.type_name``.
        """
        parts = data_item.full_scope.split(Const.SEP)
        node = self
        scope_name_list = parts[Const.RIGHT_MOVE_INDEX:]

        for name in scope_name_list:
            if name not in node.children:
                node.children[name] = LayerTrie(name, data_item.framework)
            node = node.children[name]
        node.data_items[data_item.state].append(data_item)
        node.type_name = data_item.type_name

    def query_data(self, scope, state, index, default_value=None):
        """Return the ``index``-th item stored for ``state`` at ``scope``.

        Args:
            scope: ``Const.SEP``-joined scope; its first segment is skipped.
            state: key into the node's ``data_items``.
            index: zero-based position; a negative value means "not found".
            default_value: returned when the scope, state or position does not exist.

        Returns:
            The stored data item, or ``default_value``.
        """
        # NOTE(review): insert() slices with Const.RIGHT_MOVE_INDEX while this
        # uses a literal 1 -- confirm the two always agree.
        node = self
        for name in scope.split(Const.SEP)[1:]:
            node = node.children.get(name)
            if node is None:
                return default_value

        # .get() instead of [] so a lookup never adds an empty list to the defaultdict.
        items = node.data_items.get(state, ())
        # Reject negative positions explicitly: Python would otherwise index
        # from the end (or raise IndexError on an empty list).
        if index < 0 or index >= len(items):
            return default_value
        return items[index]

    def save_to_yaml(self, output_path):
        """Write the tree as a timestamped ``<framework>_tree`` YAML file under ``output_path``."""
        result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
        file_name = add_time_with_yaml(f"{self.framework}_tree")
        file_path = os.path.join(os.path.realpath(output_path), file_name)
        save_yaml(file_path, result)

    def convert_to_dict(self, node):
        """Return a nested plain-dict view of ``node``.

        The ``"data_item"`` key maps each state to the ``data_name`` of its
        items; every child appears under ``"<child name> @ <child repr>"``.
        """
        result = {}
        result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
        for child_key, child_node in node.children.items():
            key = f"{child_key} @ {child_node}"
            result[key] = self.convert_to_dict(child_node)
        return result
81
+
82
+
83
def convert_scope(layer_trie, data_item, mapping=None):
    """Translate ``data_item.full_scope`` through the layer mapping while walking ``layer_trie``.

    At every step the mapping rules registered for the current node's
    ``type_name`` and the next scope segment are tried in order; the first rule
    whose origin equals the upcoming ``level`` segments wins and contributes
    its target to the new scope. An identity rule for the single next segment
    is always tried last.

    Args:
        layer_trie: root node of the tree that contains ``data_item``.
        data_item: item providing ``full_scope``, ``state`` and ``data_name``.
        mapping: ``{type_name: {first_segment: [(origin, target, level), ...]}}``;
            ``None`` or empty means no renaming. It is not modified.

    Returns:
        ``(new_scope, state, index)`` where ``index`` is the position of the
        last item with the same ``data_name`` among the final node's items for
        ``state``, or ``-1`` when there is none.
    """
    if not mapping:
        mapping = {}
    new_scope = Const.TOP_LAYER
    scope_list = data_item.full_scope.split(Const.SEP)
    cur_node = layer_trie

    idx = 0
    while idx < len(scope_list) - 1:
        child_name = scope_list[idx + 1]
        prefix_mapping = mapping.get(cur_node.type_name, {})
        # Build a fresh list: appending the identity rule to the list stored in
        # `mapping` would grow the caller's mapping on every call.
        mapping_list = [*prefix_mapping.get(child_name, []), (child_name, child_name, 1)]
        step = 1
        for origin, target, level in mapping_list:
            if Const.SEP.join(scope_list[idx + 1: idx + level + 1]) == origin:
                new_scope = new_scope + Const.SEP + target
                step = level
                break
        # Descend one tree level per scope segment consumed by the matched rule.
        for _ in range(step):
            cur_node = cur_node.get(scope_list[idx + 1])
            idx += 1

    index = -1
    state = data_item.state
    for position, child in enumerate(cur_node.data_items[state]):
        if data_item.data_name == child.data_name:
            index = position
    return new_scope, state, index
113
+
114
+
115
def get_data_items_and_tree(dump_json_path, output_path):
    """Load a dump file and return its data items together with the scope tree built from them.

    Args:
        dump_json_path: path of the dump JSON file.
        output_path: when truthy, forwarded to ``get_dump_data_items`` and used
            to save the tree as YAML.

    Returns:
        ``(data_items, root)`` where ``root`` is the ``LayerTrie`` holding every item.
    """
    framework = detect_framework_by_dump_json(dump_json_path)
    stack, construct = get_stack_construct_by_dump_json_path(dump_json_path)
    items = get_dump_data_items(load_json(dump_json_path), stack, construct, framework, output_path)

    tree = LayerTrie(Const.TOP_LAYER, framework)
    for item in items:
        tree.insert(item)

    if output_path:
        tree.save_to_yaml(output_path)
    return items, tree
126
+
127
+
128
def convert_data_item(npu_tree, bench_tree, npu_data_item, mapping):
    """Find the bench-side item that corresponds to ``npu_data_item``.

    The NPU item's scope is converted with ``convert_scope`` and the resulting
    ``(scope, state, index)`` triple is looked up in ``bench_tree``.

    Returns:
        Whatever ``bench_tree.query_data`` returns for that triple.
    """
    scope_state_index = convert_scope(npu_tree, npu_data_item, mapping)
    return bench_tree.query_data(*scope_state_index)
132
+
133
+
134
def update_keys_in_place(d):
    """
    Keep mappings written for the old version usable with the new one.
    The old version named the top layer 'Cell'; the new version names it 'TopLayer'.
    The ``Const.CELL`` entry is always removed from ``d``; its value is stored
    under ``Const.TOP_LAYER`` unless that value is ``None``.
    """
    legacy_value = d.pop(Const.CELL, None)
    if legacy_value is None:
        return
    d[Const.TOP_LAYER] = legacy_value
143
+
144
+
145
def preprocess_layer_mapping(mapping):
    """
    Group every mapping rule by the first segment of its key and order the
    rules of each group from the most segments to the fewest.

    ``mapping`` is first passed through ``update_keys_in_place``.

    before:
        {'A': {'a.b.c': 'new_c',
               'a.demo': 'new_demo',
               'z': 'new_z',
               'd.e': 'e'}}
    after:
        {'A': {'a': [('a.b.c', 'new_c', 3), ('a.demo', 'new_demo', 2)],
               'z': [('z', 'new_z', 1)],
               'd': [('d.e', 'e', 2)]}}
    """
    update_keys_in_place(mapping)

    final_mapping = {}
    for type_name, name_map in mapping.items():
        rules_by_prefix = {}
        final_mapping[type_name] = rules_by_prefix

        for key, value in name_map.items():
            segments = key.split('.')
            # The first segment is the lookup prefix for this rule.
            rules_by_prefix.setdefault(segments[0], []).append((key, value, len(segments)))

        # Longest rules first; rules of equal length keep their original order.
        for rules in rules_by_prefix.values():
            rules.sort(key=lambda rule: rule[-1], reverse=True)

    return final_mapping
176
+
177
+
178
def convert_data_items(npu_tree, bench_tree, npu_data_items, mapping):
    """Map every NPU data name to the name of its bench counterpart.

    Args:
        npu_tree: scope tree built from the NPU items.
        bench_tree: scope tree built from the bench items.
        npu_data_items: NPU items to convert.
        mapping: raw layer mapping; it is preprocessed here before use.

    Returns:
        ``{npu data_name: bench data_name}``; items without a counterpart map
        to ``CompareConst.N_A``.
    """
    prepared_mapping = preprocess_layer_mapping(mapping)
    api_mapping = {}
    for npu_item in npu_data_items:
        bench_item = convert_data_item(npu_tree, bench_tree, npu_item, prepared_mapping)
        if bench_item:
            api_mapping[npu_item.data_name] = bench_item.data_name
        else:
            api_mapping[npu_item.data_name] = CompareConst.N_A
    return api_mapping
187
+
188
+
189
def generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path=None, output_path=None):
    """Build the NPU -> bench API name mapping for two dump files.

    Args:
        npu_json_path: dump JSON of the NPU run.
        bench_json_path: dump JSON of the bench run.
        layer_mapping_path: YAML file with layer mapping rules; anything that
            is not a ``str`` means "no rules".
        output_path: when truthy, intermediate files and the resulting mapping
            (timestamped ``api_mapping`` YAML) are written there.

    Returns:
        The mapping produced by ``convert_data_items``.
    """
    npu_data_items, npu_root = get_data_items_and_tree(npu_json_path, output_path)
    _unused_items, bench_root = get_data_items_and_tree(bench_json_path, output_path)

    mapping = load_yaml(layer_mapping_path) if isinstance(layer_mapping_path, str) else {}
    api_mapping = convert_data_items(npu_root, bench_root, npu_data_items, mapping)

    if not output_path:
        return api_mapping

    file_name = add_time_with_yaml("api_mapping")
    save_yaml(os.path.join(os.path.realpath(output_path), file_name), api_mapping)
    return api_mapping
202
+
203
+
204
def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path=None):
    """Expand an API-level mapping into a mapping between full op names.

    For each ``npu_op_name -> bench_op_name`` pair, the full op names of both
    sides are read from the dumps and paired when the text that follows the op
    name is identical on both sides.

    Args:
        npu_json_path: dump JSON of the NPU run.
        bench_json_path: dump JSON of the bench run.
        api_mapping: ``{npu_op_name: bench_op_name}``; falsy NPU names are skipped.
        output_path: when truthy, the result is also saved there as a
            timestamped ``data_mapping`` YAML file.

    Returns:
        ``{npu_full_op_name: bench_full_op_name}``; unmatched NPU names map to
        ``CompareConst.N_A``.
    """

    def read_full_op_names(data, op_name):
        # Full op names that read_op reports for one op entry of the dump.
        return [parsed.get('full_op_name') for parsed in read_op(data.get(op_name, {}), op_name)]

    def generate_op_data_mapping(npu_op_name, npu_full_op_names, bench_op_name, bench_full_op_names):
        # Index bench names by what follows the op name, then look up each NPU name.
        bench_by_suffix = {
            full_name[len(bench_op_name):]: full_name
            for full_name in bench_full_op_names
        }
        return {
            full_name: bench_by_suffix.get(full_name[len(npu_op_name):], CompareConst.N_A)
            for full_name in npu_full_op_names
        }

    npu_data = load_json(npu_json_path).get("data", {})
    bench_data = load_json(bench_json_path).get("data", {})

    data_mapping = {}
    for npu_op_name, bench_op_name in api_mapping.items():
        if not npu_op_name:
            continue
        npu_names = read_full_op_names(npu_data, npu_op_name)
        bench_names = read_full_op_names(bench_data, bench_op_name)
        npu_names_ordered = reorder_op_name_list(npu_names)
        bench_names_ordered = reorder_op_name_list(bench_names)
        data_mapping.update(
            generate_op_data_mapping(npu_op_name, npu_names_ordered, bench_op_name, bench_names_ordered)
        )

    if not output_path:
        return data_mapping

    file_name = add_time_with_yaml("data_mapping")
    save_yaml(os.path.join(os.path.realpath(output_path), file_name), data_mapping)
    return data_mapping
240
+
241
+
242
def generate_data_mapping_by_layer_mapping(input_param, layer_mapping_path=None, output_path=None):
    """Produce the full-op-name mapping for the two dumps named in ``input_param``.

    Args:
        input_param: mapping with ``"npu_json_path"`` and ``"bench_json_path"``.
        layer_mapping_path: optional YAML file with layer mapping rules.
        output_path: forwarded to ``generate_data_mapping`` only; the
            intermediate API mapping step is run without an output directory.

    Returns:
        The mapping returned by ``generate_data_mapping``.
    """
    npu_json_path = input_param.get("npu_json_path")
    bench_json_path = input_param.get("bench_json_path")
    api_mapping = generate_api_mapping_by_layer_mapping(npu_json_path, bench_json_path, layer_mapping_path)
    return generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_path)
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import re
16
+ import math
17
+
18
+ from msprobe.core.common.const import Const
19
+
20
+
21
def postprocess_pass(data_items, name2item):
    """Run the post-processing passes over ``data_items`` in place.

    First backward items inherit the scopes of their forward counterparts, then
    the children of ``layers`` under every ``ParallelTransformer`` item are
    renumbered.
    """
    backward_pass(data_items, name2item)
    renumber_index_pass(data_items, type_name="ParallelTransformer", suffix="layers")
24
+
25
+
26
def backward_pass(data_items, name2item):
    """Copy scope information from forward items onto their backward counterparts.

    Backward data carries no stack information, so it reuses the scopes of the
    forward data. For every item whose name contains ``Const.BACKWARD`` as a
    segment at or after ``Const.SCOPE_DIRECTION_INDEX``, the forward name is
    derived by replacing ``Const.BACKWARD`` with ``Const.FORWARD`` in those
    segments. If ``name2item`` holds that name, ``stack_scope``, ``full_scope``
    and ``layer_scope`` are copied over.

    Args:
        data_items: items to update in place.
        name2item: mapping from data name to item.
    """
    direction_index = Const.SCOPE_DIRECTION_INDEX
    for item in data_items:
        name_parts = item.data_name.split(Const.SEP)
        if not name_parts:
            continue

        direction_parts = name_parts[direction_index:]
        if Const.BACKWARD not in direction_parts:
            continue

        forward_parts = name_parts[:direction_index] + [
            part.replace(Const.BACKWARD, Const.FORWARD) for part in direction_parts
        ]
        forward_item = name2item.get(Const.SEP.join(forward_parts), None)
        if not forward_item:
            continue

        item.stack_scope = forward_item.stack_scope
        item.full_scope = forward_item.full_scope
        item.layer_scope = forward_item.layer_scope
44
+
45
+
46
def extract_next_item_last_number(data, prefix, default_result=None):
    """Return the last integer found in the scope segment that directly follows ``prefix``.

    Args:
        data: dot-separated scope string.
        prefix: leading part of ``data``; it must be followed by a dot to count.
        default_result: returned when ``data`` does not start with ``prefix.``
            or the following segment contains no digits.

    Returns:
        The last run of digits in the following segment as ``int``, else
        ``default_result``.
    """
    found = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", data)
    if not found:
        return default_result

    digit_runs = re.findall(r"\d+", found.group(1))
    return int(digit_runs[-1]) if digit_runs else default_result
55
+
56
+
57
def replace_next_item_index(full_scope, prefix, index):
    """Replace the scope segment that directly follows ``prefix`` with ``index``.

    Args:
        full_scope: dot-separated scope string.
        prefix: leading part of ``full_scope``; it must be followed by a dot to count.
        index: new value for the following segment. A non-finite value
            (``inf``, ``-inf`` or ``nan``) means "nothing to renumber".

    Returns:
        The rewritten scope, or ``full_scope`` unchanged when ``index`` is not
        finite or ``full_scope`` does not start with ``prefix.``.
    """
    # isfinite also rejects NaN, which isinf lets through and which would
    # otherwise be written into the scope as the text "nan".
    if not math.isfinite(index):
        return full_scope

    match = re.search(rf"^{re.escape(prefix)}\.(\S+?)(?:\.|$)", full_scope)
    if not match:
        return full_scope

    # Splice by position rather than re.sub: a replacement template would
    # interpret backslashes contained in `prefix`.
    start, end = match.span(1)
    return f"{full_scope[:start]}{index}{full_scope[end:]}"
68
+
69
+
70
def renumber_index_pass(data_items, type_name, suffix=None):
    """
    Align member numbering in parallel-split scenarios. For example, with PP
    splitting of the ParallelTransformer layer in MindSpore, the members of
    ``layers`` are numbered globally, whereas PyTorch numbers them locally.
    To make both comparable, the indices under the selected layers are
    renumbered relative to the smallest index found, so that sequence numbers
    line up in later processing stages.

    Args:
        data_items: items whose ``full_scope`` is rewritten in place.
        type_name: items with this ``type_name`` select the prefixes to renumber.
        suffix: when given, the prefix is ``"<full_scope>.<suffix>"`` of each
            selected item; otherwise the item's ``layer_scope`` is used.
    """
    prefix_dict = {}  # prefix of each item of type `type_name` -> smallest index below it
    for data_item in data_items:
        if data_item.type_name == type_name:
            prefix = f"{data_item.full_scope}.{suffix}" if suffix else data_item.layer_scope
            prefix_dict[prefix] = math.inf

    # Smallest index that appears directly below each prefix.
    for prefix in prefix_dict:
        prefix_dict[prefix] = min(
            (extract_next_item_last_number(item.full_scope, prefix, math.inf) for item in data_items),
            default=math.inf,
        )

    # Renumber relative to that smallest index.
    for prefix, min_index in prefix_dict.items():
        if math.isinf(min_index):
            # No numbered child below this prefix. Computing inf - inf would
            # give NaN and write "nan" into matching scopes, so skip it.
            continue
        for data_item in data_items:
            abs_index = extract_next_item_last_number(data_item.full_scope, prefix, math.inf)
            if math.isinf(abs_index):
                # Item is not below this prefix, or its next segment has no number.
                continue
            data_item.full_scope = replace_next_item_index(
                data_item.full_scope, prefix, abs_index - min_index
            )