mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
@@ -22,7 +22,8 @@ from msprobe.core.common.file_utils import (check_file_type, create_directory, F
22
22
  from msprobe.core.common.const import FileCheckConst, Const
23
23
  from msprobe.core.common.utils import CompareException, get_dump_mode
24
24
  from msprobe.visualization.compare.graph_comparator import GraphComparator
25
- from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs
25
+ from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \
26
+ sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, get_step_or_rank_int
26
27
  from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
27
28
  from msprobe.core.common.log import logger
28
29
  from msprobe.visualization.graph.node_colors import NodeColors
@@ -30,8 +31,12 @@ from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_map
30
31
  from msprobe.core.compare.utils import check_and_return_dir_contents
31
32
  from msprobe.core.common.utils import detect_framework_by_dump_json
32
33
  from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
34
+ from msprobe.visualization.builder.graph_merger import GraphMerger
35
+ from msprobe.visualization.db_utils import post_process_db
33
36
 
34
37
  current_time = time.strftime("%Y%m%d%H%M%S")
38
+ build_output_db_name = f'build_{current_time}.vis.db'
39
+ compare_output_db_name = f'compare_{current_time}.vis.db'
35
40
 
36
41
 
37
42
  def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
@@ -83,32 +88,32 @@ def _export_compare_graph_result(args, result):
83
88
  graphs = [result.graph_n, result.graph_b]
84
89
  graph_comparator = result.graph_comparator
85
90
  micro_steps = result.micro_steps
86
- output_file_name = result.output_file_name
87
- if not output_file_name:
88
- output_file_name = f'compare_{current_time}.vis'
89
- logger.info(f'Start exporting compare graph result, file name: {output_file_name}...')
90
- output_path = os.path.join(args.output_path, output_file_name)
91
+ logger.info(f'Start exporting compare graph result, file name: {compare_output_db_name}...')
92
+ output_db_path = os.path.join(args.output_path, compare_output_db_name)
91
93
  task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
92
94
  export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
93
95
  NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
94
- args.overflow_check, graph_comparator.ma.compare_mode)
96
+ args.overflow_check, graph_comparator.ma.compare_mode, result.step, result.rank,
97
+ args.step_list if hasattr(args, 'step_list') else [0],
98
+ args.rank_list if hasattr(args, 'rank_list') else [0])
95
99
  try:
96
- GraphBuilder.to_json(output_path, export_config)
97
- logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}')
100
+ GraphBuilder.to_db(output_db_path, export_config)
101
+ logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_db_path}')
98
102
  return ''
99
103
  except RuntimeError as e:
100
- logger.error(f'Failed to export compare graph result, file: {output_file_name}, error: {e}')
101
- return output_file_name
104
+ logger.error(f'Failed to export compare graph result, file: {compare_output_db_name}, error: {e}')
105
+ return compare_output_db_name
102
106
 
103
107
 
104
- def _build_graph_info(dump_path, args):
108
+ def _build_graph_info(dump_path, args, graph=None):
105
109
  construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
106
110
  FileCheckConst.READ_ABLE).common_check()
107
111
  data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
108
112
  FileCheckConst.READ_ABLE).common_check()
109
113
  stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
110
114
  FileCheckConst.READ_ABLE).common_check()
111
- graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
115
+ if not graph:
116
+ graph = GraphBuilder.build(construct_path, data_path, stack_path)
112
117
  return GraphInfo(graph, construct_path, data_path, stack_path)
113
118
 
114
119
 
@@ -134,20 +139,14 @@ def _run_build_graph_compare(input_param, args, nr, br):
134
139
  def _run_build_graph_single(dump_ranks_path, rank, step, args):
135
140
  logger.info(f'Start building graph for {rank}...')
136
141
  dump_path = os.path.join(dump_ranks_path, rank)
137
- output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
138
142
  result = _build_graph_result(dump_path, args)
139
- result.output_file_name = output_file_name
140
143
  if rank != Const.RANK:
141
- try:
142
- result.rank = int(rank.replace(Const.RANK, ""))
143
- except Exception as e:
144
- logger.error('The folder name format is incorrect, expected rank+number.')
145
- raise CompareException(CompareException.INVALID_PATH_ERROR) from e
144
+ result.rank = get_step_or_rank_int(rank, True)
146
145
  logger.info(f'Building graph for step: {step}, rank: {rank} finished.')
147
146
  return result
148
147
 
149
148
 
150
- def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
149
+ def _run_graph_compare(graph_task_info, input_param, args):
151
150
  logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
152
151
  graph_n = graph_task_info.graph_info_n
153
152
  graph_b = graph_task_info.graph_info_b
@@ -159,13 +158,8 @@ def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
159
158
  graph_n.graph.overflow_check()
160
159
  graph_b.graph.overflow_check()
161
160
  graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
162
- graph_result.output_file_name = output_file_name
163
161
  if nr != Const.RANK:
164
- try:
165
- graph_result.rank = int(nr.replace(Const.RANK, ""))
166
- except Exception as e:
167
- logger.error('The folder name format is incorrect, expected rank+number.')
168
- raise CompareException(CompareException.INVALID_PATH_ERROR) from e
162
+ graph_result.rank = get_step_or_rank_int(nr, True)
169
163
  logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.')
170
164
  return graph_result
171
165
 
@@ -175,19 +169,18 @@ def _export_build_graph_result(args, result):
175
169
  graph = result.graph
176
170
  micro_steps = result.micro_steps
177
171
  overflow_check = args.overflow_check
178
- output_file_name = result.output_file_name
179
- if not output_file_name:
180
- output_file_name = f'build_{current_time}.vis'
181
- logger.info(f'Start exporting graph for {output_file_name}...')
182
- output_path = os.path.join(out_path, output_file_name)
172
+ logger.info(f'Start exporting graph for {build_output_db_name}...')
173
+ output_db_path = os.path.join(out_path, build_output_db_name)
174
+ config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check, rank=result.rank,
175
+ step=result.step, rank_list=args.rank_list if hasattr(args, 'rank_list') else [0],
176
+ step_list=args.step_list if hasattr(args, 'step_list') else [0])
183
177
  try:
184
- GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps,
185
- overflow_check=overflow_check))
186
- logger.info(f'Model graph exported successfully, the result file is saved in {output_path}')
178
+ GraphBuilder.to_db(output_db_path, config)
179
+ logger.info(f'Model graph exported successfully, the result file is saved in {output_db_path}')
187
180
  return None
188
181
  except RuntimeError as e:
189
- logger.error(f'Failed to export model graph, file: {output_file_name}, error: {e}')
190
- return output_file_name
182
+ logger.error(f'Failed to export model graph, file: {build_output_db_name}, error: {e}')
183
+ return build_output_db_name
191
184
 
192
185
 
193
186
  def is_real_data_compare(input_param, npu_ranks, bench_ranks):
@@ -205,9 +198,9 @@ def is_real_data_compare(input_param, npu_ranks, bench_ranks):
205
198
  return has_real_data
206
199
 
207
200
 
208
- def _mp_compare(input_param, serializable_args, output_file_name, nr, br):
201
+ def _mp_compare(input_param, serializable_args, nr, br):
209
202
  graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br)
210
- return _run_graph_compare(graph_task_info, input_param, serializable_args, output_file_name)
203
+ return _run_graph_compare(graph_task_info, input_param, serializable_args)
211
204
 
212
205
 
213
206
  def _compare_graph_ranks(input_param, args, step=None):
@@ -223,6 +216,8 @@ def _compare_graph_ranks(input_param, args, step=None):
223
216
  # 暂存所有rank的graph,用于匹配rank间的分布式节点
224
217
  compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call)
225
218
 
219
+ serializable_args.rank_list = [result.rank for result in compare_graph_results]
220
+
226
221
  # 匹配rank间的分布式节点
227
222
  if len(compare_graph_results) > 1:
228
223
  DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
@@ -258,27 +253,28 @@ def _get_compare_graph_results(input_param, serializable_args, step, pool, err_c
258
253
  for nr, br in zip(npu_ranks, bench_ranks):
259
254
  input_param['npu_path'] = os.path.join(dump_rank_n, nr)
260
255
  input_param['bench_path'] = os.path.join(dump_rank_b, br)
261
- output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
256
+ build_key = f'{step}_{nr}' if step else f'{nr}'
262
257
  input_param_copy = deepcopy(input_param)
263
- mp_task_dict[output_file_name] = pool.apply_async(_run_build_graph_compare,
264
- args=(input_param_copy, serializable_args, nr, br),
265
- error_callback=err_call)
258
+ mp_task_dict[build_key] = pool.apply_async(_run_build_graph_compare,
259
+ args=(input_param_copy, serializable_args, nr, br),
260
+ error_callback=err_call)
266
261
 
267
262
  mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
268
- for output_file_name, mp_res in mp_res_dict.items():
269
- compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args, output_file_name))
263
+ for mp_res in mp_res_dict.values():
264
+ compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args))
270
265
  else:
271
266
  compare_graph_tasks = []
272
267
  for nr, br in zip(npu_ranks, bench_ranks):
273
268
  input_param['npu_path'] = os.path.join(dump_rank_n, nr)
274
269
  input_param['bench_path'] = os.path.join(dump_rank_b, br)
275
- output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
276
270
  input_param_copy = deepcopy(input_param)
277
271
  compare_graph_tasks.append(pool.apply_async(_mp_compare,
278
- args=(input_param_copy, serializable_args, output_file_name, nr,
279
- br),
272
+ args=(input_param_copy, serializable_args, nr, br),
280
273
  error_callback=err_call))
281
274
  compare_graph_results = [task.get() for task in compare_graph_tasks]
275
+ if step is not None:
276
+ for result in compare_graph_results:
277
+ result.step = get_step_or_rank_int(step)
282
278
  return compare_graph_results
283
279
 
284
280
 
@@ -293,16 +289,19 @@ def _compare_graph_steps(input_param, args):
293
289
  logger.error('The number of steps in the two runs is different. Unable to match the steps.')
294
290
  raise CompareException(CompareException.INVALID_PATH_ERROR)
295
291
 
292
+ args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
293
+
296
294
  for folder_step in npu_steps:
297
295
  logger.info(f'Start processing data for {folder_step}...')
298
296
  input_param['npu_path'] = os.path.join(dump_step_n, folder_step)
299
297
  input_param['bench_path'] = os.path.join(dump_step_b, folder_step)
300
298
 
301
- _compare_graph_ranks(input_param, args, step=folder_step)
299
+ _compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \
300
+ else _compare_graph_ranks_parallel(input_param, args, step=folder_step)
302
301
 
303
302
 
304
303
  def _build_graph_ranks(dump_ranks_path, args, step=None):
305
- ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
304
+ ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
306
305
  serializable_args = SerializableArgs(args)
307
306
  with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
308
307
  def err_call(err):
@@ -319,12 +318,21 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
319
318
  error_callback=err_call))
320
319
  build_graph_results = [task.get() for task in build_graph_tasks]
321
320
 
322
- if len(build_graph_results) > 1:
321
+ if step is not None:
322
+ for result in build_graph_results:
323
+ result.step = get_step_or_rank_int(step)
324
+
325
+ if args.parallel_params:
326
+ validate_parallel_param(args.parallel_params[0], dump_ranks_path)
327
+ build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph()
328
+
329
+ if len(build_graph_results) > 1 and not args.parallel_merge:
323
330
  DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
324
331
  args.overflow_check).distributed_match()
325
332
 
326
333
  create_directory(args.output_path)
327
334
  export_build_graph_tasks = []
335
+ serializable_args.rank_list = [result.rank for result in build_graph_results]
328
336
  for result in build_graph_results:
329
337
  export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result,
330
338
  args=(serializable_args, result),
@@ -337,15 +345,84 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
337
345
  logger.info(f'Successfully exported build graph results.')
338
346
 
339
347
 
340
-
341
348
  def _build_graph_steps(dump_steps_path, args):
342
349
  steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
350
+ args.step_list = sorted([get_step_or_rank_int(step) for step in steps])
351
+
343
352
  for step in steps:
344
353
  logger.info(f'Start processing data for {step}...')
345
354
  dump_ranks_path = os.path.join(dump_steps_path, step)
346
355
  _build_graph_ranks(dump_ranks_path, args, step)
347
356
 
348
357
 
358
+ def _compare_and_export_graph(graph_task_info, input_param, args):
359
+ result = _run_graph_compare(graph_task_info, input_param, args)
360
+ return _export_compare_graph_result(args, result)
361
+
362
+
363
+ def _compare_graph_ranks_parallel(input_param, args, step=None):
364
+ args.fuzzy_match = True
365
+ npu_path = input_param.get('npu_path')
366
+ bench_path = input_param.get('bench_path')
367
+ ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK))
368
+ ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK))
369
+ parallel_params = load_parallel_param(input_param)
370
+ if len(parallel_params) != 2:
371
+ raise RuntimeError('Parallel params error in compare graph!')
372
+ validate_parallel_param(parallel_params[0], npu_path)
373
+ validate_parallel_param(parallel_params[1], bench_path, '[Bench]')
374
+ serializable_args = SerializableArgs(args)
375
+
376
+ with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
377
+ def err_call(err):
378
+ logger.error(f'Error occurred while comparing graph ranks: {err}')
379
+ try:
380
+ pool.close()
381
+ except OSError as e:
382
+ logger.error(f'Error occurred while terminating the pool: {e}')
383
+
384
+ # 1.并行构图
385
+ build_graph_tasks_n = []
386
+ build_graph_tasks_b = []
387
+ for rank in ranks_n:
388
+ build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single,
389
+ args=(npu_path, rank, step, serializable_args),
390
+ error_callback=err_call))
391
+ for rank in ranks_b:
392
+ build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single,
393
+ args=(bench_path, rank, step, serializable_args),
394
+ error_callback=err_call))
395
+ graph_results_n = [task.get() for task in build_graph_tasks_n]
396
+ graph_results_b = [task.get() for task in build_graph_tasks_b]
397
+
398
+ # 2.图合并
399
+ build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph()
400
+ build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph()
401
+ if len(build_graph_results_n) != len(build_graph_results_b):
402
+ raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} '
403
+ f'is inconsistent with that of bench: {len(build_graph_results_b)}!')
404
+ serializable_args.rank_list = [result.rank for result in build_graph_results_n]
405
+ # 3.并行图比对和输出
406
+ export_res_task_list = []
407
+ create_directory(args.output_path)
408
+ for i, result_n in enumerate(build_graph_results_n):
409
+ graph_n = result_n.graph
410
+ graph_b = build_graph_results_b[i].graph
411
+ graph_task_info = BuildGraphTaskInfo(
412
+ _build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n),
413
+ _build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
414
+ f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
415
+ export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
416
+ args=(graph_task_info, input_param, serializable_args),
417
+ error_callback=err_call))
418
+ export_res_list = [res.get() for res in export_res_task_list]
419
+ if any(export_res_list):
420
+ failed_names = list(filter(lambda x: x, export_res_list))
421
+ logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
422
+ else:
423
+ logger.info('Successfully exported compare graph results.')
424
+
425
+
349
426
  def _graph_service_parser(parser):
350
427
  parser.add_argument("-i", "--input_path", dest="input_path", type=str,
351
428
  help="<Required> The compare input path, a dict json.", required=True)
@@ -357,19 +434,20 @@ def _graph_service_parser(parser):
357
434
  help="<Optional> whether open overflow_check for graph.", required=False)
358
435
  parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
359
436
  help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
360
- parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true",
361
- help="<Optional> Whether to use complete stack information.", required=False)
362
437
 
363
438
 
364
439
  def _graph_service_command(args):
365
440
  input_param = load_json(args.input_path)
366
441
  npu_path = input_param.get("npu_path")
367
442
  bench_path = input_param.get("bench_path")
443
+ args.parallel_merge = check_whether_parallel_merge(input_param)
444
+ args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None
368
445
  check_file_or_directory_path(npu_path, isdir=True)
369
446
  if bench_path:
370
447
  check_file_or_directory_path(bench_path, isdir=True)
371
448
  if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
372
449
  content = check_directory_content(npu_path)
450
+ output_db_path = os.path.join(args.output_path, build_output_db_name)
373
451
  if content == GraphConst.RANKS:
374
452
  _build_graph_ranks(npu_path, args)
375
453
  elif content == GraphConst.STEPS:
@@ -383,10 +461,14 @@ def _graph_service_command(args):
383
461
  elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
384
462
  content_n = check_directory_content(npu_path)
385
463
  content_b = check_directory_content(bench_path)
464
+ output_db_path = os.path.join(args.output_path, compare_output_db_name)
386
465
  if content_n != content_b:
387
466
  raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
388
467
  if content_n == GraphConst.RANKS:
389
- _compare_graph_ranks(input_param, args)
468
+ if args.parallel_merge:
469
+ _compare_graph_ranks_parallel(input_param, args)
470
+ else:
471
+ _compare_graph_ranks(input_param, args)
390
472
  elif content_n == GraphConst.STEPS:
391
473
  _compare_graph_steps(input_param, args)
392
474
  else:
@@ -398,6 +480,8 @@ def _graph_service_command(args):
398
480
  else:
399
481
  logger.error("The npu_path or bench_path should be a folder.")
400
482
  raise CompareException(CompareException.INVALID_COMPARE_MODE)
483
+ # 所有数据输出db结束后,添加索引,修改权限
484
+ post_process_db(output_db_path)
401
485
 
402
486
 
403
487
  def _pt_graph_service_parser(parser):
@@ -417,18 +501,18 @@ def _ms_graph_service_command(args):
417
501
 
418
502
 
419
503
class CompareGraphResult:
    """Plain container for the objects that belong to one graph comparison.

    Stores the two graphs, the comparator, the micro-step value and the
    rank/step identifiers exactly as passed in; no validation is performed.
    """

    def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, step=0):
        # presumably graph_n is the NPU graph and graph_b the bench graph -- confirm with callers
        self.graph_n, self.graph_b = graph_n, graph_b
        self.graph_comparator = graph_comparator
        self.micro_steps = micro_steps
        self.rank, self.step = rank, step
427
511
 
428
512
 
429
513
class BuildGraphResult:
    """Plain container for one built graph and the identifiers that go with it.

    All arguments are stored unchanged; ``micro_steps``, ``rank`` and ``step``
    default to 0.
    """

    def __init__(self, graph, micro_steps=0, rank=0, step=0):
        self.graph = graph
        self.micro_steps = micro_steps
        self.rank, self.step = rank, step
@@ -20,6 +20,8 @@ import pickle
20
20
  from msprobe.core.common.file_utils import FileOpen
21
21
  from msprobe.core.common.const import CompareConst, Const
22
22
  from msprobe.core.common.log import logger
23
+ from msprobe.core.common.exceptions import MsprobeException
24
+ from msprobe.core.compare.utils import check_and_return_dir_contents
23
25
 
24
26
 
25
27
  def load_json_file(file_path):
@@ -57,6 +59,21 @@ def str2float(percentage_str):
57
59
  return 0
58
60
 
59
61
 
62
def get_step_or_rank_int(x: str, is_rank=False):
    """
    Return the integer suffix of a ``rank{int}`` or ``step{int}`` folder name.

    A bare ``Const.RANK`` or ``Const.STEP`` name (no number) yields 0.

    Args:
        x: Folder name to parse.
        is_rank: Parse ``x`` as a rank folder when True, otherwise as a step folder.

    Returns:
        The integer that follows the expected prefix.

    Raises:
        RuntimeError: If ``x`` is not the expected prefix followed only by digits.
    """
    if x in [Const.RANK, Const.STEP]:
        return 0
    description = Const.RANK if is_rank else Const.STEP
    try:
        # Anchor on the prefix: str.replace() would also strip the keyword from the
        # middle or the end of the name, so e.g. "1rank2" used to parse as 12.
        if not x.startswith(description):
            raise ValueError(x)
        digits = x[len(description):]
        # int() on its own tolerates signs, surrounding whitespace and "_" separators,
        # none of which match the "{prefix}+number" format the error message promises.
        if not digits.isdigit():
            raise ValueError(x)
        x_int = int(digits)
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError/TypeError cover a non-string x, which was reported the same way before.
        logger.error(f'The folder name format is incorrect, expected {description}+number, such as rank0, step1, etc.')
        raise RuntimeError(f'Invalid {description} folder name: {x!r}') from e
    return x_int
75
+
76
+
60
77
  def check_directory_content(input_path):
61
78
  """
62
79
  检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件
@@ -102,6 +119,83 @@ def check_directory_content(input_path):
102
119
  "all rank{number} named folders (such as rank0), or all files.")
103
120
 
104
121
 
122
def extract_rank_number(rank_str):
    """Return the integer found after the first four characters of ``rank_str``.

    Falls back to 0 when that remainder cannot be converted with ``int()``
    (for example when it is empty).
    """
    suffix = rank_str[4:]
    try:
        number = int(suffix)
    except ValueError:
        number = 0
    return number
127
+
128
+
129
def sort_rank_number_strings(rank_number_strings):
    """Return a new list of the given names ordered by ``extract_rank_number``.

    The input iterable is left untouched; names with equal keys keep their
    original relative order.
    """
    ordered = list(rank_number_strings)
    ordered.sort(key=extract_rank_number)
    return ordered
132
+
133
+
134
def check_whether_parallel_merge(input_param):
    """Tell whether ``input_param`` asks for a parallel merge.

    Returns True only when ``input_param["parallel_merge"]`` is a dict whose
    ``'npu'`` entry is truthy; every other shape yields False.
    """
    merge_config = input_param.get("parallel_merge")
    # An empty dict has no 'npu' entry, so it is rejected by the same lookup.
    return isinstance(merge_config, dict) and bool(merge_config.get('npu'))
141
+
142
+
143
def _build_parallel_param(config):
    """Create a ParallelParam from one side's config dict, applying the vpp/order defaults."""
    return ParallelParam(config.get('rank_size'), config.get('tp'), config.get('pp'), config.get('vpp', 1),
                         config.get('order', 'tp-cp-ep-dp-pp'))


def load_parallel_param(input_param):
    """
    Build the ParallelParam objects described by ``input_param["parallel_merge"]``.

    Args:
        input_param: Mapping that may hold a ``parallel_merge`` dict with ``npu``
            and, optionally, ``bench`` sub-dicts.

    Returns:
        ``(param_n,)`` when no bench config is given, otherwise ``(param_n, param_b)``.
    """
    # ``or {}`` rather than a .get() default: a key that is present but null (JSON
    # ``"bench": null``) would otherwise return None and fail on the next .get().
    parallel_merge = input_param.get("parallel_merge") or {}
    config_n = parallel_merge.get('npu') or {}
    config_b = parallel_merge.get('bench') or {}
    param_n = _build_parallel_param(config_n)
    if not config_b:
        return (param_n,)
    return param_n, _build_parallel_param(config_b)
152
+
153
+
154
def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
    """
    Check the parallel-merge parameters for consistency with the dumped rank folders.

    Args:
        parallel_param: Object exposing ``rank_size``, ``tp``, ``pp``, ``vpp`` and ``order``.
        dump_path: Directory handed to ``check_and_return_dir_contents`` to list the
            ``Const.RANK`` entries.
        log_prefix: Tag prepended to every error log line.

    Raises:
        MsprobeException: ``INVALID_PARAM_ERROR`` on the first check that fails.
    """
    required = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size]
    ranks = check_and_return_dir_contents(dump_path, Const.RANK)
    # Null check first, so a missing rank_size is not reported as "you set None but expected N".
    if any(x is None for x in required):
        logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    # vpp joins the numeric checks: it was compared below without ever being validated,
    # so a null or non-numeric value surfaced as a raw TypeError.
    numeric = required + [parallel_param.vpp]
    if any(isinstance(x, bool) or not isinstance(x, int) for x in numeric):
        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be integers!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if len(ranks) != parallel_param.rank_size:
        logger.error(f'{log_prefix} The parallel param "rank_size" error, '
                     f'you set {parallel_param.rank_size} but expected {len(ranks)}.')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if any(x <= 0 for x in numeric):
        logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.tp > parallel_param.rank_size:
        logger.error(f'{log_prefix} The parallel param "tp" must be less than or equal to "rank_size"!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.pp > parallel_param.rank_size:
        logger.error(f'{log_prefix} The parallel param "pp" must be less than or equal to "rank_size"!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.rank_size % parallel_param.tp != 0:
        logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "tp"!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.rank_size % parallel_param.pp != 0:
        logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "pp"!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.tp * parallel_param.pp > parallel_param.rank_size:
        logger.error(f'{log_prefix} The parallel params "tp * pp" must be less than or equal to "rank_size"!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if parallel_param.vpp > 1 and parallel_param.pp < 2:
        logger.error(f'{log_prefix} When configuring the parallel param "vpp", the "pp" param must be greater than 1!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
    if not isinstance(parallel_param.order, str):
        logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
        raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
188
+
189
+
190
class ParallelParam:
    """Holder for one side's parallel-merge settings.

    Values are stored exactly as given; ``vpp`` defaults to 1 and ``order``
    to ``'tp-cp-ep-dp-pp'``.
    """

    def __init__(self, rank_size, tp, pp, vpp=1, order='tp-cp-ep-dp-pp'):
        self.rank_size, self.tp, self.pp = rank_size, tp, pp
        self.vpp, self.order = vpp, order
197
+
198
+
105
199
  class ToolTip:
106
200
  MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值'
107
201
  MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值'
@@ -147,10 +241,11 @@ class GraphConst:
147
241
  INPUT = '.input.'
148
242
  OUTPUT = '.output.'
149
243
  STR_MAX_LEN = 50
150
- MD5_INDEX_LIST = [CompareConst.RESULT]
151
- REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
152
- SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
244
+ MD5_INDEX_LIST = CompareConst.MD5_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
245
+ REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
246
+ SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
153
247
  APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
248
+ APIS_BETWEEN_MODULES_ALL_RANKS = 'Apis_Between_Modules_All_Ranks'
154
249
  NULL = 'null'
155
250
  NONE = 'None'
156
251
  VALUE = 'value'
@@ -184,9 +279,13 @@ class GraphConst:
184
279
  OP = 'op'
185
280
  PEER = 'peer'
186
281
  GROUP_ID = 'group_id'
187
-
282
+
283
+ UNCERTAINTY_THRESHOLD = 1e-6
284
+ REDUCE_OPERATIONS = ['reduce_scatter', 'all_reduce']
285
+
188
286
  IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty',
189
287
  'empty_strided'}
288
+ VPP_CHUNK_0 = '0'
190
289
 
191
290
 
192
291
  def is_serializable(obj):