mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -22,7 +22,8 @@ from msprobe.core.common.file_utils import (check_file_type, create_directory, F
|
|
|
22
22
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
23
23
|
from msprobe.core.common.utils import CompareException, get_dump_mode
|
|
24
24
|
from msprobe.visualization.compare.graph_comparator import GraphComparator
|
|
25
|
-
from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs
|
|
25
|
+
from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs, load_parallel_param, \
|
|
26
|
+
sort_rank_number_strings, check_whether_parallel_merge, validate_parallel_param, get_step_or_rank_int
|
|
26
27
|
from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
|
|
27
28
|
from msprobe.core.common.log import logger
|
|
28
29
|
from msprobe.visualization.graph.node_colors import NodeColors
|
|
@@ -30,8 +31,12 @@ from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_map
|
|
|
30
31
|
from msprobe.core.compare.utils import check_and_return_dir_contents
|
|
31
32
|
from msprobe.core.common.utils import detect_framework_by_dump_json
|
|
32
33
|
from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
|
|
34
|
+
from msprobe.visualization.builder.graph_merger import GraphMerger
|
|
35
|
+
from msprobe.visualization.db_utils import post_process_db
|
|
33
36
|
|
|
34
37
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
38
|
+
build_output_db_name = f'build_{current_time}.vis.db'
|
|
39
|
+
compare_output_db_name = f'compare_{current_time}.vis.db'
|
|
35
40
|
|
|
36
41
|
|
|
37
42
|
def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
|
|
@@ -83,32 +88,32 @@ def _export_compare_graph_result(args, result):
|
|
|
83
88
|
graphs = [result.graph_n, result.graph_b]
|
|
84
89
|
graph_comparator = result.graph_comparator
|
|
85
90
|
micro_steps = result.micro_steps
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
output_file_name = f'compare_{current_time}.vis'
|
|
89
|
-
logger.info(f'Start exporting compare graph result, file name: {output_file_name}...')
|
|
90
|
-
output_path = os.path.join(args.output_path, output_file_name)
|
|
91
|
+
logger.info(f'Start exporting compare graph result, file name: {compare_output_db_name}...')
|
|
92
|
+
output_db_path = os.path.join(args.output_path, compare_output_db_name)
|
|
91
93
|
task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
|
|
92
94
|
export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
|
|
93
95
|
NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
|
|
94
|
-
args.overflow_check, graph_comparator.ma.compare_mode
|
|
96
|
+
args.overflow_check, graph_comparator.ma.compare_mode, result.step, result.rank,
|
|
97
|
+
args.step_list if hasattr(args, 'step_list') else [0],
|
|
98
|
+
args.rank_list if hasattr(args, 'rank_list') else [0])
|
|
95
99
|
try:
|
|
96
|
-
GraphBuilder.
|
|
97
|
-
logger.info(f'Exporting compare graph result successfully, the result file is saved in {
|
|
100
|
+
GraphBuilder.to_db(output_db_path, export_config)
|
|
101
|
+
logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_db_path}')
|
|
98
102
|
return ''
|
|
99
103
|
except RuntimeError as e:
|
|
100
|
-
logger.error(f'Failed to export compare graph result, file: {
|
|
101
|
-
return
|
|
104
|
+
logger.error(f'Failed to export compare graph result, file: {compare_output_db_name}, error: {e}')
|
|
105
|
+
return compare_output_db_name
|
|
102
106
|
|
|
103
107
|
|
|
104
|
-
def _build_graph_info(dump_path, args):
|
|
108
|
+
def _build_graph_info(dump_path, args, graph=None):
|
|
105
109
|
construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE), FileCheckConst.FILE,
|
|
106
110
|
FileCheckConst.READ_ABLE).common_check()
|
|
107
111
|
data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
|
|
108
112
|
FileCheckConst.READ_ABLE).common_check()
|
|
109
113
|
stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
|
|
110
114
|
FileCheckConst.READ_ABLE).common_check()
|
|
111
|
-
|
|
115
|
+
if not graph:
|
|
116
|
+
graph = GraphBuilder.build(construct_path, data_path, stack_path)
|
|
112
117
|
return GraphInfo(graph, construct_path, data_path, stack_path)
|
|
113
118
|
|
|
114
119
|
|
|
@@ -134,20 +139,14 @@ def _run_build_graph_compare(input_param, args, nr, br):
|
|
|
134
139
|
def _run_build_graph_single(dump_ranks_path, rank, step, args):
|
|
135
140
|
logger.info(f'Start building graph for {rank}...')
|
|
136
141
|
dump_path = os.path.join(dump_ranks_path, rank)
|
|
137
|
-
output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
|
|
138
142
|
result = _build_graph_result(dump_path, args)
|
|
139
|
-
result.output_file_name = output_file_name
|
|
140
143
|
if rank != Const.RANK:
|
|
141
|
-
|
|
142
|
-
result.rank = int(rank.replace(Const.RANK, ""))
|
|
143
|
-
except Exception as e:
|
|
144
|
-
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
145
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
144
|
+
result.rank = get_step_or_rank_int(rank, True)
|
|
146
145
|
logger.info(f'Building graph for step: {step}, rank: {rank} finished.')
|
|
147
146
|
return result
|
|
148
147
|
|
|
149
148
|
|
|
150
|
-
def _run_graph_compare(graph_task_info, input_param, args
|
|
149
|
+
def _run_graph_compare(graph_task_info, input_param, args):
|
|
151
150
|
logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
|
|
152
151
|
graph_n = graph_task_info.graph_info_n
|
|
153
152
|
graph_b = graph_task_info.graph_info_b
|
|
@@ -159,13 +158,8 @@ def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
|
|
|
159
158
|
graph_n.graph.overflow_check()
|
|
160
159
|
graph_b.graph.overflow_check()
|
|
161
160
|
graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
|
|
162
|
-
graph_result.output_file_name = output_file_name
|
|
163
161
|
if nr != Const.RANK:
|
|
164
|
-
|
|
165
|
-
graph_result.rank = int(nr.replace(Const.RANK, ""))
|
|
166
|
-
except Exception as e:
|
|
167
|
-
logger.error('The folder name format is incorrect, expected rank+number.')
|
|
168
|
-
raise CompareException(CompareException.INVALID_PATH_ERROR) from e
|
|
162
|
+
graph_result.rank = get_step_or_rank_int(nr, True)
|
|
169
163
|
logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.')
|
|
170
164
|
return graph_result
|
|
171
165
|
|
|
@@ -175,19 +169,18 @@ def _export_build_graph_result(args, result):
|
|
|
175
169
|
graph = result.graph
|
|
176
170
|
micro_steps = result.micro_steps
|
|
177
171
|
overflow_check = args.overflow_check
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
172
|
+
logger.info(f'Start exporting graph for {build_output_db_name}...')
|
|
173
|
+
output_db_path = os.path.join(out_path, build_output_db_name)
|
|
174
|
+
config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check, rank=result.rank,
|
|
175
|
+
step=result.step, rank_list=args.rank_list if hasattr(args, 'rank_list') else [0],
|
|
176
|
+
step_list=args.step_list if hasattr(args, 'step_list') else [0])
|
|
183
177
|
try:
|
|
184
|
-
GraphBuilder.
|
|
185
|
-
|
|
186
|
-
logger.info(f'Model graph exported successfully, the result file is saved in {output_path}')
|
|
178
|
+
GraphBuilder.to_db(output_db_path, config)
|
|
179
|
+
logger.info(f'Model graph exported successfully, the result file is saved in {output_db_path}')
|
|
187
180
|
return None
|
|
188
181
|
except RuntimeError as e:
|
|
189
|
-
logger.error(f'Failed to export model graph, file: {
|
|
190
|
-
return
|
|
182
|
+
logger.error(f'Failed to export model graph, file: {build_output_db_name}, error: {e}')
|
|
183
|
+
return build_output_db_name
|
|
191
184
|
|
|
192
185
|
|
|
193
186
|
def is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
@@ -205,9 +198,9 @@ def is_real_data_compare(input_param, npu_ranks, bench_ranks):
|
|
|
205
198
|
return has_real_data
|
|
206
199
|
|
|
207
200
|
|
|
208
|
-
def _mp_compare(input_param, serializable_args,
|
|
201
|
+
def _mp_compare(input_param, serializable_args, nr, br):
|
|
209
202
|
graph_task_info = _run_build_graph_compare(input_param, serializable_args, nr, br)
|
|
210
|
-
return _run_graph_compare(graph_task_info, input_param, serializable_args
|
|
203
|
+
return _run_graph_compare(graph_task_info, input_param, serializable_args)
|
|
211
204
|
|
|
212
205
|
|
|
213
206
|
def _compare_graph_ranks(input_param, args, step=None):
|
|
@@ -223,6 +216,8 @@ def _compare_graph_ranks(input_param, args, step=None):
|
|
|
223
216
|
# 暂存所有rank的graph,用于匹配rank间的分布式节点
|
|
224
217
|
compare_graph_results = _get_compare_graph_results(input_param, serializable_args, step, pool, err_call)
|
|
225
218
|
|
|
219
|
+
serializable_args.rank_list = [result.rank for result in compare_graph_results]
|
|
220
|
+
|
|
226
221
|
# 匹配rank间的分布式节点
|
|
227
222
|
if len(compare_graph_results) > 1:
|
|
228
223
|
DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
|
|
@@ -258,27 +253,28 @@ def _get_compare_graph_results(input_param, serializable_args, step, pool, err_c
|
|
|
258
253
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
259
254
|
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
260
255
|
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
261
|
-
|
|
256
|
+
build_key = f'{step}_{nr}' if step else f'{nr}'
|
|
262
257
|
input_param_copy = deepcopy(input_param)
|
|
263
|
-
mp_task_dict[
|
|
264
|
-
|
|
265
|
-
|
|
258
|
+
mp_task_dict[build_key] = pool.apply_async(_run_build_graph_compare,
|
|
259
|
+
args=(input_param_copy, serializable_args, nr, br),
|
|
260
|
+
error_callback=err_call)
|
|
266
261
|
|
|
267
262
|
mp_res_dict = {k: v.get() for k, v in mp_task_dict.items()}
|
|
268
|
-
for
|
|
269
|
-
compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args
|
|
263
|
+
for mp_res in mp_res_dict.values():
|
|
264
|
+
compare_graph_results.append(_run_graph_compare(mp_res, input_param, serializable_args))
|
|
270
265
|
else:
|
|
271
266
|
compare_graph_tasks = []
|
|
272
267
|
for nr, br in zip(npu_ranks, bench_ranks):
|
|
273
268
|
input_param['npu_path'] = os.path.join(dump_rank_n, nr)
|
|
274
269
|
input_param['bench_path'] = os.path.join(dump_rank_b, br)
|
|
275
|
-
output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
|
|
276
270
|
input_param_copy = deepcopy(input_param)
|
|
277
271
|
compare_graph_tasks.append(pool.apply_async(_mp_compare,
|
|
278
|
-
args=(input_param_copy, serializable_args,
|
|
279
|
-
br),
|
|
272
|
+
args=(input_param_copy, serializable_args, nr, br),
|
|
280
273
|
error_callback=err_call))
|
|
281
274
|
compare_graph_results = [task.get() for task in compare_graph_tasks]
|
|
275
|
+
if step is not None:
|
|
276
|
+
for result in compare_graph_results:
|
|
277
|
+
result.step = get_step_or_rank_int(step)
|
|
282
278
|
return compare_graph_results
|
|
283
279
|
|
|
284
280
|
|
|
@@ -293,16 +289,19 @@ def _compare_graph_steps(input_param, args):
|
|
|
293
289
|
logger.error('The number of steps in the two runs is different. Unable to match the steps.')
|
|
294
290
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
295
291
|
|
|
292
|
+
args.step_list = sorted([get_step_or_rank_int(step) for step in npu_steps])
|
|
293
|
+
|
|
296
294
|
for folder_step in npu_steps:
|
|
297
295
|
logger.info(f'Start processing data for {folder_step}...')
|
|
298
296
|
input_param['npu_path'] = os.path.join(dump_step_n, folder_step)
|
|
299
297
|
input_param['bench_path'] = os.path.join(dump_step_b, folder_step)
|
|
300
298
|
|
|
301
|
-
_compare_graph_ranks(input_param, args, step=folder_step)
|
|
299
|
+
_compare_graph_ranks(input_param, args, step=folder_step) if not args.parallel_merge \
|
|
300
|
+
else _compare_graph_ranks_parallel(input_param, args, step=folder_step)
|
|
302
301
|
|
|
303
302
|
|
|
304
303
|
def _build_graph_ranks(dump_ranks_path, args, step=None):
|
|
305
|
-
ranks =
|
|
304
|
+
ranks = sort_rank_number_strings(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
|
|
306
305
|
serializable_args = SerializableArgs(args)
|
|
307
306
|
with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
|
|
308
307
|
def err_call(err):
|
|
@@ -319,12 +318,21 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
|
|
|
319
318
|
error_callback=err_call))
|
|
320
319
|
build_graph_results = [task.get() for task in build_graph_tasks]
|
|
321
320
|
|
|
322
|
-
if
|
|
321
|
+
if step is not None:
|
|
322
|
+
for result in build_graph_results:
|
|
323
|
+
result.step = get_step_or_rank_int(step)
|
|
324
|
+
|
|
325
|
+
if args.parallel_params:
|
|
326
|
+
validate_parallel_param(args.parallel_params[0], dump_ranks_path)
|
|
327
|
+
build_graph_results = GraphMerger(build_graph_results, args.parallel_params[0]).merge_graph()
|
|
328
|
+
|
|
329
|
+
if len(build_graph_results) > 1 and not args.parallel_merge:
|
|
323
330
|
DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
|
|
324
331
|
args.overflow_check).distributed_match()
|
|
325
332
|
|
|
326
333
|
create_directory(args.output_path)
|
|
327
334
|
export_build_graph_tasks = []
|
|
335
|
+
serializable_args.rank_list = [result.rank for result in build_graph_results]
|
|
328
336
|
for result in build_graph_results:
|
|
329
337
|
export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result,
|
|
330
338
|
args=(serializable_args, result),
|
|
@@ -337,15 +345,84 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
|
|
|
337
345
|
logger.info(f'Successfully exported build graph results.')
|
|
338
346
|
|
|
339
347
|
|
|
340
|
-
|
|
341
348
|
def _build_graph_steps(dump_steps_path, args):
|
|
342
349
|
steps = sorted(check_and_return_dir_contents(dump_steps_path, Const.STEP))
|
|
350
|
+
args.step_list = sorted([get_step_or_rank_int(step) for step in steps])
|
|
351
|
+
|
|
343
352
|
for step in steps:
|
|
344
353
|
logger.info(f'Start processing data for {step}...')
|
|
345
354
|
dump_ranks_path = os.path.join(dump_steps_path, step)
|
|
346
355
|
_build_graph_ranks(dump_ranks_path, args, step)
|
|
347
356
|
|
|
348
357
|
|
|
358
|
+
def _compare_and_export_graph(graph_task_info, input_param, args):
|
|
359
|
+
result = _run_graph_compare(graph_task_info, input_param, args)
|
|
360
|
+
return _export_compare_graph_result(args, result)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def _compare_graph_ranks_parallel(input_param, args, step=None):
|
|
364
|
+
args.fuzzy_match = True
|
|
365
|
+
npu_path = input_param.get('npu_path')
|
|
366
|
+
bench_path = input_param.get('bench_path')
|
|
367
|
+
ranks_n = sort_rank_number_strings(check_and_return_dir_contents(npu_path, Const.RANK))
|
|
368
|
+
ranks_b = sort_rank_number_strings(check_and_return_dir_contents(bench_path, Const.RANK))
|
|
369
|
+
parallel_params = load_parallel_param(input_param)
|
|
370
|
+
if len(parallel_params) != 2:
|
|
371
|
+
raise RuntimeError('Parallel params error in compare graph!')
|
|
372
|
+
validate_parallel_param(parallel_params[0], npu_path)
|
|
373
|
+
validate_parallel_param(parallel_params[1], bench_path, '[Bench]')
|
|
374
|
+
serializable_args = SerializableArgs(args)
|
|
375
|
+
|
|
376
|
+
with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
|
|
377
|
+
def err_call(err):
|
|
378
|
+
logger.error(f'Error occurred while comparing graph ranks: {err}')
|
|
379
|
+
try:
|
|
380
|
+
pool.close()
|
|
381
|
+
except OSError as e:
|
|
382
|
+
logger.error(f'Error occurred while terminating the pool: {e}')
|
|
383
|
+
|
|
384
|
+
# 1.并行构图
|
|
385
|
+
build_graph_tasks_n = []
|
|
386
|
+
build_graph_tasks_b = []
|
|
387
|
+
for rank in ranks_n:
|
|
388
|
+
build_graph_tasks_n.append(pool.apply_async(_run_build_graph_single,
|
|
389
|
+
args=(npu_path, rank, step, serializable_args),
|
|
390
|
+
error_callback=err_call))
|
|
391
|
+
for rank in ranks_b:
|
|
392
|
+
build_graph_tasks_b.append(pool.apply_async(_run_build_graph_single,
|
|
393
|
+
args=(bench_path, rank, step, serializable_args),
|
|
394
|
+
error_callback=err_call))
|
|
395
|
+
graph_results_n = [task.get() for task in build_graph_tasks_n]
|
|
396
|
+
graph_results_b = [task.get() for task in build_graph_tasks_b]
|
|
397
|
+
|
|
398
|
+
# 2.图合并
|
|
399
|
+
build_graph_results_n = GraphMerger(graph_results_n, parallel_params[0]).merge_graph()
|
|
400
|
+
build_graph_results_b = GraphMerger(graph_results_b, parallel_params[1], True).merge_graph()
|
|
401
|
+
if len(build_graph_results_n) != len(build_graph_results_b):
|
|
402
|
+
raise RuntimeError(f'Parallel merge failed because the dp of npu: {len(build_graph_results_n)} '
|
|
403
|
+
f'is inconsistent with that of bench: {len(build_graph_results_b)}!')
|
|
404
|
+
serializable_args.rank_list = [result.rank for result in build_graph_results_n]
|
|
405
|
+
# 3.并行图比对和输出
|
|
406
|
+
export_res_task_list = []
|
|
407
|
+
create_directory(args.output_path)
|
|
408
|
+
for i, result_n in enumerate(build_graph_results_n):
|
|
409
|
+
graph_n = result_n.graph
|
|
410
|
+
graph_b = build_graph_results_b[i].graph
|
|
411
|
+
graph_task_info = BuildGraphTaskInfo(
|
|
412
|
+
_build_graph_info(os.path.join(npu_path, f'rank{graph_n.root.rank}'), args, graph_n),
|
|
413
|
+
_build_graph_info(os.path.join(bench_path, f'rank{graph_b.root.rank}'), args, graph_b),
|
|
414
|
+
f'rank{graph_n.root.rank}', f'rank{graph_b.root.rank}', current_time)
|
|
415
|
+
export_res_task_list.append(pool.apply_async(_compare_and_export_graph,
|
|
416
|
+
args=(graph_task_info, input_param, serializable_args),
|
|
417
|
+
error_callback=err_call))
|
|
418
|
+
export_res_list = [res.get() for res in export_res_task_list]
|
|
419
|
+
if any(export_res_list):
|
|
420
|
+
failed_names = list(filter(lambda x: x, export_res_list))
|
|
421
|
+
logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
|
|
422
|
+
else:
|
|
423
|
+
logger.info('Successfully exported compare graph results.')
|
|
424
|
+
|
|
425
|
+
|
|
349
426
|
def _graph_service_parser(parser):
|
|
350
427
|
parser.add_argument("-i", "--input_path", dest="input_path", type=str,
|
|
351
428
|
help="<Required> The compare input path, a dict json.", required=True)
|
|
@@ -357,19 +434,20 @@ def _graph_service_parser(parser):
|
|
|
357
434
|
help="<Optional> whether open overflow_check for graph.", required=False)
|
|
358
435
|
parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true",
|
|
359
436
|
help="<Optional> Whether to perform a fuzzy match on the api name.", required=False)
|
|
360
|
-
parser.add_argument("-cs", "--complete_stack", dest="complete_stack", action="store_true",
|
|
361
|
-
help="<Optional> Whether to use complete stack information.", required=False)
|
|
362
437
|
|
|
363
438
|
|
|
364
439
|
def _graph_service_command(args):
|
|
365
440
|
input_param = load_json(args.input_path)
|
|
366
441
|
npu_path = input_param.get("npu_path")
|
|
367
442
|
bench_path = input_param.get("bench_path")
|
|
443
|
+
args.parallel_merge = check_whether_parallel_merge(input_param)
|
|
444
|
+
args.parallel_params = load_parallel_param(input_param) if args.parallel_merge else None
|
|
368
445
|
check_file_or_directory_path(npu_path, isdir=True)
|
|
369
446
|
if bench_path:
|
|
370
447
|
check_file_or_directory_path(bench_path, isdir=True)
|
|
371
448
|
if check_file_type(npu_path) == FileCheckConst.DIR and not bench_path:
|
|
372
449
|
content = check_directory_content(npu_path)
|
|
450
|
+
output_db_path = os.path.join(args.output_path, build_output_db_name)
|
|
373
451
|
if content == GraphConst.RANKS:
|
|
374
452
|
_build_graph_ranks(npu_path, args)
|
|
375
453
|
elif content == GraphConst.STEPS:
|
|
@@ -383,10 +461,14 @@ def _graph_service_command(args):
|
|
|
383
461
|
elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
|
|
384
462
|
content_n = check_directory_content(npu_path)
|
|
385
463
|
content_b = check_directory_content(bench_path)
|
|
464
|
+
output_db_path = os.path.join(args.output_path, compare_output_db_name)
|
|
386
465
|
if content_n != content_b:
|
|
387
466
|
raise ValueError('The directory structures of npu_path and bench_path are inconsistent.')
|
|
388
467
|
if content_n == GraphConst.RANKS:
|
|
389
|
-
|
|
468
|
+
if args.parallel_merge:
|
|
469
|
+
_compare_graph_ranks_parallel(input_param, args)
|
|
470
|
+
else:
|
|
471
|
+
_compare_graph_ranks(input_param, args)
|
|
390
472
|
elif content_n == GraphConst.STEPS:
|
|
391
473
|
_compare_graph_steps(input_param, args)
|
|
392
474
|
else:
|
|
@@ -398,6 +480,8 @@ def _graph_service_command(args):
|
|
|
398
480
|
else:
|
|
399
481
|
logger.error("The npu_path or bench_path should be a folder.")
|
|
400
482
|
raise CompareException(CompareException.INVALID_COMPARE_MODE)
|
|
483
|
+
# 所有数据输出db结束后,添加索引,修改权限
|
|
484
|
+
post_process_db(output_db_path)
|
|
401
485
|
|
|
402
486
|
|
|
403
487
|
def _pt_graph_service_parser(parser):
|
|
@@ -417,18 +501,18 @@ def _ms_graph_service_command(args):
|
|
|
417
501
|
|
|
418
502
|
|
|
419
503
|
class CompareGraphResult:
|
|
420
|
-
def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0,
|
|
504
|
+
def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, step=0):
|
|
421
505
|
self.graph_n = graph_n
|
|
422
506
|
self.graph_b = graph_b
|
|
423
507
|
self.graph_comparator = graph_comparator
|
|
424
508
|
self.micro_steps = micro_steps
|
|
425
509
|
self.rank = rank
|
|
426
|
-
self.
|
|
510
|
+
self.step = step
|
|
427
511
|
|
|
428
512
|
|
|
429
513
|
class BuildGraphResult:
|
|
430
|
-
def __init__(self, graph, micro_steps, rank=0,
|
|
514
|
+
def __init__(self, graph, micro_steps=0, rank=0, step=0):
|
|
431
515
|
self.graph = graph
|
|
432
516
|
self.micro_steps = micro_steps
|
|
433
517
|
self.rank = rank
|
|
434
|
-
self.
|
|
518
|
+
self.step = step
|
msprobe/visualization/utils.py
CHANGED
|
@@ -20,6 +20,8 @@ import pickle
|
|
|
20
20
|
from msprobe.core.common.file_utils import FileOpen
|
|
21
21
|
from msprobe.core.common.const import CompareConst, Const
|
|
22
22
|
from msprobe.core.common.log import logger
|
|
23
|
+
from msprobe.core.common.exceptions import MsprobeException
|
|
24
|
+
from msprobe.core.compare.utils import check_and_return_dir_contents
|
|
23
25
|
|
|
24
26
|
|
|
25
27
|
def load_json_file(file_path):
|
|
@@ -57,6 +59,21 @@ def str2float(percentage_str):
|
|
|
57
59
|
return 0
|
|
58
60
|
|
|
59
61
|
|
|
62
|
+
def get_step_or_rank_int(x: str, is_rank=False):
|
|
63
|
+
"""
|
|
64
|
+
获取字符串rank{int}或者step{int}中的int值,如果x=rank或step,返回0
|
|
65
|
+
"""
|
|
66
|
+
if x in [Const.RANK, Const.STEP]:
|
|
67
|
+
return 0
|
|
68
|
+
description = Const.RANK if is_rank else Const.STEP
|
|
69
|
+
try:
|
|
70
|
+
x_int = int(x.replace(Const.RANK, "")) if is_rank else int(x.replace(Const.STEP, ""))
|
|
71
|
+
except Exception as e:
|
|
72
|
+
logger.error(f'The folder name format is incorrect, expected {description}+number, such as rank0, step1, etc.')
|
|
73
|
+
raise RuntimeError from e
|
|
74
|
+
return x_int
|
|
75
|
+
|
|
76
|
+
|
|
60
77
|
def check_directory_content(input_path):
|
|
61
78
|
"""
|
|
62
79
|
检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件
|
|
@@ -102,6 +119,83 @@ def check_directory_content(input_path):
|
|
|
102
119
|
"all rank{number} named folders (such as rank0), or all files.")
|
|
103
120
|
|
|
104
121
|
|
|
122
|
+
def extract_rank_number(rank_str):
|
|
123
|
+
try:
|
|
124
|
+
return int(rank_str[4:])
|
|
125
|
+
except ValueError:
|
|
126
|
+
return 0
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def sort_rank_number_strings(rank_number_strings):
|
|
130
|
+
sorted_list = sorted(rank_number_strings, key=extract_rank_number)
|
|
131
|
+
return sorted_list
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def check_whether_parallel_merge(input_param):
|
|
135
|
+
parallel_merge = input_param.get("parallel_merge")
|
|
136
|
+
if not isinstance(parallel_merge, dict) or not parallel_merge:
|
|
137
|
+
return False
|
|
138
|
+
if not parallel_merge.get('npu'):
|
|
139
|
+
return False
|
|
140
|
+
return True
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def load_parallel_param(input_param):
|
|
144
|
+
parallel_merge = input_param.get("parallel_merge", {})
|
|
145
|
+
config_n = parallel_merge.get('npu', {})
|
|
146
|
+
config_b = parallel_merge.get('bench', {})
|
|
147
|
+
param_n = ParallelParam(config_n.get('rank_size'), config_n.get('tp'), config_n.get('pp'), config_n.get('vpp', 1),
|
|
148
|
+
config_n.get('order', 'tp-cp-ep-dp-pp'))
|
|
149
|
+
param_b = ParallelParam(config_b.get('rank_size'), config_b.get('tp'), config_b.get('pp'), config_b.get('vpp', 1),
|
|
150
|
+
config_b.get('order', 'tp-cp-ep-dp-pp'))
|
|
151
|
+
return (param_n,) if not config_b else (param_n, param_b)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def validate_parallel_param(parallel_param, dump_path, log_prefix='[NPU]'):
|
|
155
|
+
params = [parallel_param.tp, parallel_param.pp, parallel_param.rank_size]
|
|
156
|
+
ranks = check_and_return_dir_contents(dump_path, Const.RANK)
|
|
157
|
+
if len(ranks) != parallel_param.rank_size:
|
|
158
|
+
logger.error(f'{log_prefix} The parallel param "rank_size" error, '
|
|
159
|
+
f'you set {parallel_param.rank_size} but expected {len(ranks)}.')
|
|
160
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
161
|
+
if any(x is None for x in params):
|
|
162
|
+
logger.error(f'{log_prefix} The parallel params "tp/pp/rank_size" must not be null!')
|
|
163
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
164
|
+
if any(x <= 0 for x in params):
|
|
165
|
+
logger.error(f'{log_prefix} The parallel params "tp/pp/vpp/rank_size" must be greater than 0!')
|
|
166
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
167
|
+
if parallel_param.tp > parallel_param.rank_size:
|
|
168
|
+
logger.error(f'{log_prefix} The parallel param "tp" must be less than or equal to "rank_size"!')
|
|
169
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
170
|
+
if parallel_param.pp > parallel_param.rank_size:
|
|
171
|
+
logger.error(f'{log_prefix} The parallel param "pp" must be less than or equal to "rank_size"!')
|
|
172
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
173
|
+
if parallel_param.rank_size % parallel_param.tp != 0:
|
|
174
|
+
logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "tp"!')
|
|
175
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
176
|
+
if parallel_param.rank_size % parallel_param.pp != 0:
|
|
177
|
+
logger.error(f'{log_prefix} The parallel param "rank_size" must be divisible by "pp"!')
|
|
178
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
179
|
+
if parallel_param.tp * parallel_param.pp > parallel_param.rank_size:
|
|
180
|
+
logger.error(f'{log_prefix} The parallel params "tp * pp" must be less than or equal to "rank_size"!')
|
|
181
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
182
|
+
if parallel_param.vpp > 1 and parallel_param.pp < 2:
|
|
183
|
+
logger.error(f'{log_prefix} When configuring the parallel param "vpp", the "pp" param must be greater than 1!')
|
|
184
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
185
|
+
if not isinstance(parallel_param.order, str):
|
|
186
|
+
logger.error(f'{log_prefix} The parallel params "order" must be of string type!')
|
|
187
|
+
raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class ParallelParam:
|
|
191
|
+
def __init__(self, rank_size, tp, pp, vpp=1, order='tp-cp-ep-dp-pp'):
|
|
192
|
+
self.rank_size = rank_size
|
|
193
|
+
self.tp = tp
|
|
194
|
+
self.pp = pp
|
|
195
|
+
self.vpp = vpp
|
|
196
|
+
self.order = order
|
|
197
|
+
|
|
198
|
+
|
|
105
199
|
class ToolTip:
|
|
106
200
|
MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值'
|
|
107
201
|
MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值'
|
|
@@ -147,10 +241,11 @@ class GraphConst:
|
|
|
147
241
|
INPUT = '.input.'
|
|
148
242
|
OUTPUT = '.output.'
|
|
149
243
|
STR_MAX_LEN = 50
|
|
150
|
-
MD5_INDEX_LIST = [CompareConst.
|
|
151
|
-
REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
|
|
152
|
-
SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
|
|
244
|
+
MD5_INDEX_LIST = CompareConst.MD5_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
|
|
245
|
+
REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
|
|
246
|
+
SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX + [CompareConst.REQ_GRAD_CONSIST]
|
|
153
247
|
APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
|
|
248
|
+
APIS_BETWEEN_MODULES_ALL_RANKS = 'Apis_Between_Modules_All_Ranks'
|
|
154
249
|
NULL = 'null'
|
|
155
250
|
NONE = 'None'
|
|
156
251
|
VALUE = 'value'
|
|
@@ -184,9 +279,13 @@ class GraphConst:
|
|
|
184
279
|
OP = 'op'
|
|
185
280
|
PEER = 'peer'
|
|
186
281
|
GROUP_ID = 'group_id'
|
|
187
|
-
|
|
282
|
+
|
|
283
|
+
UNCERTAINTY_THRESHOLD = 1e-6
|
|
284
|
+
REDUCE_OPERATIONS = ['reduce_scatter', 'all_reduce']
|
|
285
|
+
|
|
188
286
|
IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty',
|
|
189
287
|
'empty_strided'}
|
|
288
|
+
VPP_CHUNK_0 = '0'
|
|
190
289
|
|
|
191
290
|
|
|
192
291
|
def is_serializable(obj):
|