mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -0,0 +1,197 @@
1
+ # Copyright (c) 2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import math
18
+ from msprobe.core.common.const import CompareConst, Const
19
+ from msprobe.visualization.utils import ToolTip, GraphConst, str2float
20
+
21
+
22
class ModeAdapter:
    """Bind compare results to graph nodes according to the compare mode.

    ``compare_mode`` is one of the ``GraphConst`` modes checked below (MD5
    compare, summary compare, or real-data compare). For each mode the adapter
    copies the relevant compare-result columns onto the entries of a node's
    ``input_data`` / ``output_data`` and derives one precision index per node.
    """

    def __init__(self, compare_mode):
        self.compare_mode = compare_mode
        # Compare-result rows accumulated in real-data mode (see add_csv_data).
        self.csv_data = []
        # Nodes queued in real-data mode (see prepare_real_data).
        self.compare_nodes = []

    @staticmethod
    def _add_md5_compare_data(node_data, compare_data_dict):
        """Attach MD5 compare columns to every matched entry of ``node_data``.

        Args:
            node_data: mapping of data name -> data-info dict; non-dict values
                are skipped. Matched dicts are updated in place.
            compare_data_dict: mapping of data name -> compare-result row.

        Returns:
            ``GraphConst.MAX_INDEX_KEY`` if no matched entry has a result other
            than ``CompareConst.PASS``; otherwise ``GraphConst.MIN_INDEX_KEY``.
        """
        precision_index = GraphConst.MAX_INDEX_KEY
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if not compare_data:
                continue
            headers = CompareConst.MD5_COMPARE_RESULT_HEADER
            id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST]
            # Updates ``value`` in place, i.e. the dict stored in node_data.
            ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list)
            # A single failing MD5 comparison marks the whole side as failed.
            if value.get(CompareConst.RESULT) != CompareConst.PASS:
                precision_index = GraphConst.MIN_INDEX_KEY
        return precision_index

    @staticmethod
    def _to_valid_thousandth(raw_value):
        """Convert a one-thousandth error ratio to float, or None if unusable.

        The raw value may be None, a non-numeric string, or NaN; all of these
        yield None. NaN must be rejected explicitly: ``float('nan')`` parses
        successfully, but a NaN inside ``min()`` makes the result depend on
        element order and would propagate into the node's precision index.
        """
        try:
            number = float(raw_value)
        except (ValueError, TypeError):
            return None
        return None if math.isnan(number) else number

    @staticmethod
    def _add_real_compare_data(node_data, compare_data_dict):
        """Attach real-data compare columns and return the smallest thousandth ratio.

        Args:
            node_data: mapping of data name -> data-info dict; non-dict values
                are skipped. Matched dicts are updated in place.
            compare_data_dict: mapping of data name -> compare-result row.

        Returns:
            The minimum of all usable one-thousandth error ratios of matched,
            non-scalar entries, capped at 1.0. Returns None when no entry
            produced a usable ratio.
        """
        numbers = []
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if not compare_data:
                continue
            headers = CompareConst.COMPARE_RESULT_HEADER
            id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST]
            ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list)
            # Skip scalar data: thousandth metrics cannot be computed for it (they would be NaN).
            if not value.get(Const.SHAPE):
                continue
            thousandth = ModeAdapter._to_valid_thousandth(value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO))
            if thousandth is not None:
                numbers.append(thousandth)
        # Abnormal case: no usable thousandth ratio at all.
        if not numbers:
            return None
        return min(numbers + [float(1)])

    @staticmethod
    def _add_summary_compare_data(node_data, compare_data_dict):
        """Attach summary compare columns and return the largest relative error.

        Args:
            node_data: mapping of data name -> data-info dict (e.g.
                {'type': 'torch.Tensor', 'dtype': 'torch.float32', 'shape': [...], 'Max': ..., ...});
                non-dict values are skipped. Matched dicts are updated in place.
            compare_data_dict: mapping of data name -> compare-result row.

        Returns:
            The maximum relative error over matched entries, starting from
            ``GraphConst.MIN_INDEX_KEY`` and capped at 1.
        """
        max_relative_err = GraphConst.MIN_INDEX_KEY
        for key, data_info in node_data.items():
            if not isinstance(data_info, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if not compare_data:
                continue
            # Columns of the compare-result CSV to copy onto the node.
            key_list = GraphConst.SUMMARY_INDEX_LIST
            headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
            id_list = [headers.index(x) for x in key_list]
            ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
            # Entries from index 4 onward are treated as relative errors; this relies on
            # the column order of GraphConst.SUMMARY_INDEX_LIST (TODO confirm if it changes).
            for item in key_list[4:]:
                relative_err = str2float(data_info.get(item))
                max_relative_err = max(max_relative_err, relative_err)
        max_relative_err = 1 if max_relative_err > 1 else max_relative_err
        return max_relative_err

    @staticmethod
    def _match_data(data_dict, compare_data, key_list, id_list):
        """Copy ``compare_data[id]`` into ``data_dict[key]`` for each (id, key) pair.

        Binds precision metrics to a node's input_data / output_data entry.
        Does nothing if ``key_list`` and ``id_list`` differ in length.
        """
        if len(key_list) != len(id_list):
            return
        for id_val, key in zip(id_list, key_list):
            data_dict[key] = compare_data[id_val]

    @staticmethod
    def _check_list_len(data_list, len_num):
        """Raise ValueError if ``data_list`` has fewer than ``len_num`` items."""
        if len(data_list) < len_num:
            raise ValueError(f"compare_data_dict_list must contain at least {len_num} items.")

    def parse_result(self, node, compare_data_dict_list):
        """Bind compare results to ``node`` and compute its precision index.

        Args:
            node: graph node exposing ``input_data`` and ``output_data`` dicts.
            compare_data_dict_list: in MD5 and summary modes, ``[input_dict, output_dict]``;
                in real-data mode, a single dict at index 0 used for both inputs and outputs.

        Returns:
            ``(precision_index, other_dict)``. ``other_dict`` holds extra node data
            (the overall MD5 result in MD5 mode, empty otherwise).

        Raises:
            ValueError: if ``compare_data_dict_list`` is too short for the mode.
        """
        other_dict = {}
        if self.compare_mode == GraphConst.MD5_COMPARE:
            ModeAdapter._check_list_len(compare_data_dict_list, 2)
            precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict_list[0])
            precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict_list[1])
            # The node passes only if every input and output MD5 comparison passes.
            precision_index = min(precision_index_in, precision_index_out)
            other_result = CompareConst.PASS if precision_index == GraphConst.MAX_INDEX_KEY else CompareConst.DIFF
            other_dict[CompareConst.RESULT] = other_result
        elif self.compare_mode == GraphConst.SUMMARY_COMPARE:
            ModeAdapter._check_list_len(compare_data_dict_list, 2)
            # Inputs are annotated, but only the output side determines the index.
            ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict_list[0])
            precision_index = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict_list[1])
        else:
            ModeAdapter._check_list_len(compare_data_dict_list, 1)
            min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict_list[0])
            min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict_list[0])
            if min_thousandth_in is not None and min_thousandth_out is not None:
                change_percentage = min_thousandth_in - min_thousandth_out
            else:
                change_percentage = GraphConst.MIN_INDEX_KEY
            # Clamp into [MIN_INDEX_KEY, MAX_INDEX_KEY].
            change_percentage = GraphConst.MIN_INDEX_KEY if change_percentage < GraphConst.MIN_INDEX_KEY \
                else change_percentage
            precision_index = GraphConst.MAX_INDEX_KEY \
                if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
        return precision_index, other_dict

    def prepare_real_data(self, node):
        """Queue ``node`` for real-data comparison.

        Returns:
            True if the mode is real-data compare (the node was queued), else False.
        """
        if self.compare_mode == GraphConst.REAL_DATA_COMPARE:
            self.compare_nodes.append(node)
            return True
        return False

    def add_csv_data(self, compare_result_list):
        """Accumulate compare-result rows; only effective in real-data mode."""
        if self.compare_mode != GraphConst.REAL_DATA_COMPARE:
            return
        self.csv_data.extend(compare_result_list)

    def add_error_key(self, node_data):
        """Set the mode-specific list of error-metric names on each dict entry of ``node_data``."""
        for value in node_data.values():
            if not isinstance(value, dict):
                continue
            if self.compare_mode == GraphConst.SUMMARY_COMPARE:
                message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
                           CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
            elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
                message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
            else:
                # No error metrics are reported in other modes (output optimization).
                message = []
            # Updates the dict stored in node_data in place.
            value[GraphConst.ERROR_KEY] = message

    def get_tool_tip(self):
        """Return a JSON string mapping displayed field names to their front-end tooltips."""
        if self.compare_mode == GraphConst.SUMMARY_COMPARE:
            tips = {
                CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
                CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
                CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
                CompareConst.NORM_DIFF: ToolTip.NORM_DIFF}
        elif self.compare_mode == GraphConst.MD5_COMPARE:
            tips = {Const.MD5: ToolTip.MD5}
        else:
            tips = {
                CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
                CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
                CompareConst.COSINE: ToolTip.COSINE,
                CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
                CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
        return json.dumps(tips)
@@ -0,0 +1,14 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
@@ -0,0 +1,119 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from msprobe.core.overflow_check.level import OverflowLevel
16
+ from msprobe.visualization.graph.node_op import NodeOp
17
+ from msprobe.visualization.utils import GraphConst
18
+ from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
19
+
20
+
21
class BaseNode:
    """A node of a visualization graph.

    Holds the node's input/output data, its parent (``upnode``) and children
    (``subnodes``), and matching information filled in during graph comparison.
    """

    # Defining __eq__ already makes instances unhashable; stated explicitly so the
    # intent is visible (a structural __eq__ cannot be paired with an id-based hash).
    __hash__ = None

    def __init__(self, node_op, node_id, up_node=None):
        self.op = node_op
        self.id = node_id
        self.data = {}
        self.output_data = {}
        self.input_data = {}
        self.upnode = None
        self.add_upnode(up_node)
        self.subnodes = []
        self.matched_node_link = []
        self.suggestions = {}
        self.stack_info = []
        self.micro_step_id = None
        self.overflow_level = None
        self.matched_distributed = {}

    def __str__(self):
        info = f'id:\t{self.id}'
        return info

    def __eq__(self, other):
        """Return True if the two nodes can be matched, i.e. are structurally consistent.

        Both input_data and output_data must agree according to ``compare_data``.
        Comparing against a non-BaseNode returns NotImplemented (so ``node == None``
        evaluates to False) instead of failing on a missing ``input_data`` attribute.
        """
        if not isinstance(other, BaseNode):
            return NotImplemented
        if not compare_data(self.input_data, other.input_data):
            return False
        if not compare_data(self.output_data, other.output_data):
            return False
        return True

    def fuzzy_eq(self, other):
        """Like ``__eq__`` but uses the fuzzy matcher ``compare_data_fuzzy``."""
        if not compare_data_fuzzy(self.input_data, other.input_data):
            return False
        if not compare_data_fuzzy(self.output_data, other.output_data):
            return False
        return True

    def set_input_output(self, input_data, output_data):
        """Replace the node's input and output data."""
        self.input_data = input_data
        self.output_data = output_data

    def set_overflow_level(self, level):
        """Record an overflow level; falsy values and non-OverflowLevel values are ignored."""
        if not level or not isinstance(level, OverflowLevel):
            return
        self.overflow_level = level
        self.data[GraphConst.OVERFLOW_LEVEL] = self.overflow_level.value

    def add_upnode(self, node):
        """Attach ``node`` as this node's parent and register self as its subnode.

        Ignored if ``node`` is falsy, has the same id as self, or a parent is already set.
        """
        if not node or node.id == self.id or self.upnode:
            return
        self.upnode = node
        node.subnodes.append(self)

    def add_link(self, node, ancestors):
        """Record matching data after a successful node match.

        Args:
            node: the node matched with self.
            ancestors: ancestor information of the counterpart node; the same
                list object is stored on both nodes.
        """
        self.matched_node_link = ancestors
        node.matched_node_link = ancestors

    def to_dict(self):
        """Serialize the node for output.

        ``upnode`` is emitted as the string 'None' when there is no parent;
        ``micro_step_id`` and ``matched_distributed`` are included only when set.
        """
        result = {
            'id': self.id,
            'node_type': self.op.value,
            'output_data': format_node_data(self.output_data),
            'input_data': format_node_data(self.input_data),
            'upnode': self.upnode.id if self.upnode else 'None',
            'subnodes': [node.id for node in self.subnodes],
            'matched_node_link': self.matched_node_link,
            'suggestions': self.suggestions,
            'stack_info': self.stack_info
        }
        if self.micro_step_id is not None:
            result['micro_step_id'] = self.micro_step_id
        result['data'] = self.data
        if self.matched_distributed:
            result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
        return result

    def get_ancestors(self):
        """Return the ids of all ancestors, ordered from the root down to the parent."""
        ancestors = []
        current_node = self.upnode
        while current_node:
            ancestors.append(current_node.id)
            current_node = current_node.upnode
        return list(reversed(ancestors))
@@ -0,0 +1,318 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from enum import Enum
16
+ from msprobe.visualization.utils import GraphConst
17
+ from msprobe.core.common.const import Const, CompareConst
18
+ from msprobe.core.common.log import logger
19
+
20
+
21
+ class CommunicationType(Enum):
22
+ """
23
+ 通信类型:发送、接收、发送接收
24
+ """
25
+ SEND = 'send'
26
+ RECEIVE = 'receive'
27
+ SEND_RECEIVE = 'send_receive'
28
+
29
+
30
+ class DistributedType(Enum):
31
+ """
32
+ 分布式类型:点对点通信、集体通信
33
+ """
34
+ P2P = 'p2p'
35
+ COLLECTIVE = 'collective'
36
+
37
+
38
+ CANNOT_MATCH = 'cannot match distributed node in rank'
39
+
40
+
41
+ class DistributedAnalyzer:
42
+
43
+ def __init__(self, graphs: dict, overflow_check: bool):
44
+ self.graphs = graphs
45
+ self.overflow_check = overflow_check
46
+ self.config = {
47
+ # 当前通信api名称: 匹配目标通信api名称, 获取rank信息的位置参数或关键字参数, 通信类型, 分布式类型
48
+ 'send': ['recv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
49
+ 'isend': ['irecv', GraphConst.DST, CommunicationType.SEND.value, DistributedType.P2P],
50
+ 'recv': ['send', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
51
+ 'irecv': ['isend', GraphConst.SRC, CommunicationType.RECEIVE.value, DistributedType.P2P],
52
+ 'broadcast': ['broadcast', '1', CommunicationType.SEND.value, DistributedType.COLLECTIVE],
53
+ 'scatter': ['scatter', GraphConst.SRC, CommunicationType.SEND.value, DistributedType.COLLECTIVE],
54
+ 'gather': ['gather', GraphConst.DST, CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE],
55
+ 'reduce': ['reduce', '1', CommunicationType.RECEIVE.value, DistributedType.COLLECTIVE]
56
+ }
57
+ self.group_node_mapping = {}
58
+ self._make_group_node_mapping()
59
+
60
+ @staticmethod
61
+ def _get_opposite_communication_type(action):
62
+ if action == CommunicationType.SEND.value:
63
+ return CommunicationType.RECEIVE.value
64
+ elif action == CommunicationType.RECEIVE.value:
65
+ return CommunicationType.SEND.value
66
+ return action
67
+
68
+ @staticmethod
69
+ def _node_output_all_equal(data: dict, target_data: dict):
70
+ keys_to_compare = [Const.DTYPE, Const.SHAPE, Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
71
+ return all(data.get(key) == target_data.get(key) for key in keys_to_compare)
72
+
73
+ @staticmethod
74
+ def _get_target_rank(node, rank, parameter):
75
+ """
76
+ 点对点通信, 从输出数据参数src或dst, 获取通信目标rank
77
+ 一对多通信和多对一通信, 从输出数据参数src或dst或位置参数, 获取发送或接收的rank源头
78
+ :param node: 当前节点
79
+ :param rank: 当前rank
80
+ :param parameter: 输出数据参数
81
+ :return: 目标rank
82
+ """
83
+ target_rank = node.input_data.get(f'{node.id}{GraphConst.INPUT}{parameter}', {}).get('value')
84
+ if target_rank is None:
85
+ logger.warning(f'The parameter {parameter} of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
86
+ return target_rank
87
+
88
+ @staticmethod
89
+ def _get_group_info(node, rank):
90
+ """
91
+ 获取当前通信节点的group参数中的group_ranks和group_id
92
+ :param node: 当前通信节点
93
+ :param rank: 当前rank
94
+ :return: group_ranks和group_id
95
+ """
96
+ group = node.input_data.get(f'{node.id}{GraphConst.INPUT}group', {})
97
+ if not group:
98
+ logger.warning(f'The kwarg group of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
99
+ return None, None
100
+ group_ranks = group.get('group_ranks')
101
+ if not group_ranks:
102
+ logger.warning(f'The group_ranks of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
103
+ return None, None
104
+ group_id = group.get('group_id')
105
+ if not group_id:
106
+ logger.warning(f'The group_id of node {node.id} does not exist, {CANNOT_MATCH}{rank}')
107
+ return None, None
108
+ return group_ranks, group_id
109
+
110
+ def distributed_match(self):
111
+ for rank, graph in self.graphs.items():
112
+ nodes = graph.node_map
113
+ for node_id, node in nodes.items():
114
+ # 不是通信节点或者已经匹配过了
115
+ if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
116
+ continue
117
+ api_name, distributed_type = self._get_distributed_name_and_type(node_id)
118
+ if distributed_type == DistributedType.P2P:
119
+ self._p2p_match(node, rank, api_name)
120
+ else:
121
+ self._collective_match(node, rank, api_name)
122
+
123
+ def _make_group_node_mapping(self):
124
+ """
125
+ 建立通信节点的全局唯一标识映射
126
+ key: rank号, value: unique_group_id与node_id之间的映射
127
+ {
128
+ "0": {
129
+ "unique_group_id1": "node_id1",
130
+ "unique_group_id2": "node_id2",
131
+ "node_id1": "unique_group_id1",
132
+ "node_id2": "unique_group_id2"
133
+ },
134
+ "1": {},
135
+ "2": {}
136
+ }
137
+ """
138
+ for rank, graph in self.graphs.items():
139
+ group_count = {}
140
+ group_info = {}
141
+ nodes = graph.node_map
142
+ for node_id, node in nodes.items():
143
+ if not node_id.startswith(Const.DISTRIBUTED):
144
+ continue
145
+ api_name, distributed_type = self._get_distributed_name_and_type(node_id)
146
+ if distributed_type == DistributedType.P2P:
147
+ config_info = self.config.get(api_name)
148
+ target_rank = self._get_target_rank(node, rank, config_info[1])
149
+ if target_rank is None:
150
+ continue
151
+ # p2p通信节点,api名称+传输目标rank作为group_id
152
+ group_id = api_name + Const.RANK + str(target_rank)
153
+ else:
154
+ # 其他通信节点直接获取group_id, 并拼接api名称
155
+ _, group_id = self._get_group_info(node, rank)
156
+ if not group_id:
157
+ continue
158
+ group_id += api_name
159
+ # 同group_id的调用次数累计
160
+ group_count[group_id] = group_count.get(group_id, 0) + 1
161
+ # group_id+同group_id的调用次数作为唯一的unique_group_id
162
+ unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
163
+ group_info[unique_group_id] = node_id
164
+ group_info[node_id] = unique_group_id
165
+ self.group_node_mapping[rank] = group_info
166
+
167
+ def _get_distributed_name_and_type(self, node_id):
168
+ if Const.SEP not in node_id:
169
+ raise ValueError(f'Invalid node id {node_id}.')
170
+ api_name = node_id.split(Const.SEP)[1]
171
+ if api_name in self.config:
172
+ return api_name, self.config.get(api_name)[3]
173
+ return api_name, DistributedType.COLLECTIVE
174
+
175
+ def _get_target_node(self, rank, unique_group_id, api_name, target_rank, target_api_name=None):
176
+ """
177
+ 获取名称匹配上的目标节点
178
+ :param rank: 当前rank
179
+ :param unique_group_id: 当前节点唯一group id
180
+ :param api_name: 当前节点的api名称, 例如Distributed.isend.0.forward, api名称为isend
181
+ :param target_rank: 与当前节点产生通信的rank
182
+ :param target_api_name: 与当前节点产生通信的节点api名称, 仅p2p通信需要配置
183
+ :return: 目标节点
184
+ """
185
+ target_graph = self.graphs.get(target_rank)
186
+ if not target_graph:
187
+ logger.warning(f'Graph data does not exist, {CANNOT_MATCH}{target_rank}')
188
+ return None
189
+ target_group_mapping = self.group_node_mapping.get(target_rank)
190
+ # p2p通信,想要获取目标节点,需要替换unique_group_id中的rank和api name,
191
+ # 例如isend发送到rank1,对应的irecv接收自rank0, isend_rank1与irecv_rank0对应
192
+ target_unique_group_id = (unique_group_id
193
+ .replace(Const.RANK + str(target_rank), Const.RANK + str(rank))
194
+ .replace(api_name, target_api_name)) if target_api_name else unique_group_id
195
+ target_node_id = target_group_mapping.get(target_unique_group_id, '')
196
+ target_node = target_graph.node_map.get(target_node_id)
197
+ if not target_node:
198
+ logger.warning(f'Node {target_node_id} does not exist, {CANNOT_MATCH}{target_rank}')
199
+ return None
200
+ return target_node
201
+
202
+ def _add_node_matched_distributed(self, node, target_node, api_name, target_rank, reversal_type=False):
203
+ """
204
+ 给当前节点添加matched_distributed字段信息
205
+ :param node: 当前节点
206
+ :param target_node: 匹配上的目标节点
207
+ :param api_name: 当前节点的api名称
208
+ :param target_rank: 匹配上的目标rank
209
+ :param reversal_type: 是否需要反转通信类型,例如broadcast在rank0通信类型是发送,但在其他rank通信类型是接收
210
+ """
211
+ communications_type = self.config.get(api_name)[2]
212
+ communications_type = self._get_opposite_communication_type(communications_type) if reversal_type \
213
+ else communications_type
214
+ index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
215
+ else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
216
+ matched_distributed = {
217
+ 'communications_type': communications_type,
218
+ 'nodes_info': {target_rank: [str(index), target_node.id]}
219
+ }
220
+ node.matched_distributed = matched_distributed
221
+
222
+ def _p2p_match(self, node, rank, api_name):
223
+ """
224
+ 点对点通信匹配
225
+
226
+ 根据当前点对点通信节点的输出数据中的src或dst参数, 确定目标rank, 并从目标rank中找到对应的点对点通信节点, 校验输出数据是否一致,
227
+ 校验通过则在两个匹配节点增加匹配信息
228
+ Args:
229
+ node: 当前点对点通信节点
230
+ rank: 当前节点所属rank
231
+ api_name: 当前节点的api名称
232
+ Returns:
233
+ """
234
+ config_info = self.config.get(api_name)
235
+ target_api_name = config_info[0]
236
+ #
237
+ target_rank = self._get_target_rank(node, rank, config_info[1])
238
+ if target_rank is None:
239
+ return
240
+ unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
241
+ target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
242
+ if not target_node:
243
+ return
244
+ target_config_info = self.config.get(target_api_name)
245
+ source_rank = (target_node.input_data.get(f'{target_node.id}{GraphConst.INPUT}{target_config_info[1]}', {})
246
+ .get('value'))
247
+ if source_rank is None:
248
+ logger.warning(
249
+ f'The kwarg {target_config_info[1]} of node {target_node.id} does not exist, '
250
+ f'{CANNOT_MATCH}{source_rank}')
251
+ return
252
+ if source_rank != rank:
253
+ # 点对点通信,待匹配目标节点包含的rank信息要与当前rank一致
254
+ logger.warning(
255
+ f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}, '
256
+ f'but the data shows that {target_node.id} communicates with rank{source_rank}.'
257
+ f'The rank is inconsistent, cannot match distributed node')
258
+ return
259
+
260
+ # 点对点通信,两个匹配节点的输出数据要一致
261
+ if not DistributedAnalyzer._node_output_all_equal(node.output_data.get(node.id + '.output.0'),
262
+ target_node.output_data.get(target_node.id + '.output.0')):
263
+ logger.warning(f'{node.id} output of rank{rank} is different from the {target_node.id} '
264
+ f'output of rank{target_rank}, cannot match distributed node')
265
+ return
266
+
267
+ self._add_node_matched_distributed(node, target_node, api_name, target_rank)
268
+ self._add_node_matched_distributed(target_node, node, target_api_name, rank)
269
+
270
+ def _collective_match(self, node, rank, api_name):
271
+ """
272
+ 集体通信匹配
273
+
274
+ 一对多通信和多对一通信, 需要先获取节点输出数据中的src或dst或位置参数, 确定发送源或接收源, 多对多通信不需要
275
+ :param node: 当前集体通信节点
276
+ :param rank: 当前节点所属rank
277
+ :param api_name: 当前节点的api名称
278
+ :return:
279
+ """
280
+ communications_type = CommunicationType.SEND_RECEIVE.value
281
+ config_info = self.config.get(api_name)
282
+ if config_info:
283
+ # 此时为一对多通信或多对一通信
284
+ source_rank = self._get_target_rank(node, rank, config_info[1])
285
+ if source_rank is None or str(source_rank) != str(rank):
286
+ return
287
+ communications_type = config_info[2]
288
+ group_ranks, group_id = self._get_group_info(node, rank)
289
+ if not group_ranks or not group_id:
290
+ return
291
+ unique_group_id = self.group_node_mapping.get(rank, {}).get(node.id, '')
292
+ matched_distributed = {'communications_type': communications_type}
293
+ nodes_info = {}
294
+ for target_rank in group_ranks:
295
+ if str(target_rank) == str(rank):
296
+ continue
297
+ target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank)
298
+ if not target_node:
299
+ continue
300
+ _, target_group_id = self._get_group_info(target_node, target_rank)
301
+ if not target_group_id:
302
+ continue
303
+ if group_id != target_group_id:
304
+ logger.warning(
305
+ f'{node.id} of rank{rank} is expected to communicate with {target_node.id} of rank{target_rank}'
306
+ f', but the data shows that the group id of the two nodes are different, '
307
+ f'cannot match distributed node')
308
+ continue
309
+ # 给当前通信节点添加matched_distributed字段信息
310
+ index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
311
+ else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
312
+ nodes_info[target_rank] = [str(index), target_node.id]
313
+ if config_info:
314
+ # 给匹配上的目标节点也添加matched_distributed字段信息
315
+ self._add_node_matched_distributed(target_node, node, api_name, rank, True)
316
+ if nodes_info:
317
+ matched_distributed['nodes_info'] = nodes_info
318
+ node.matched_distributed = matched_distributed