mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
- mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +16 -21
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +185 -11
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +4 -0
- msprobe/core/common/utils.py +42 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +381 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +199 -69
- msprobe/core/data_dump/data_collector.py +100 -25
- msprobe/core/data_dump/data_processor/base.py +130 -28
- msprobe/core/data_dump/data_processor/factory.py +8 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
- msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
- msprobe/core/data_dump/json_writer.py +54 -8
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +121 -17
- msprobe/docs/02.config_introduction.md +18 -16
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +107 -58
- msprobe/docs/06.data_dump_MindSpore.md +95 -34
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +310 -220
- msprobe/docs/21.visualization_PyTorch.md +125 -35
- msprobe/docs/22.visualization_MindSpore.md +149 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +525 -0
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +11 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +68 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +28 -2
- msprobe/mindspore/debugger/precision_debugger.py +100 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +13 -3
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +347 -107
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -7
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +116 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +42 -17
- msprobe/pytorch/debugger/precision_debugger.py +56 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +98 -28
- msprobe/pytorch/monitor/csv2tb.py +164 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +543 -318
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +76 -56
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +84 -48
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +264 -115
- msprobe/visualization/builder/graph_builder.py +93 -10
- msprobe/visualization/builder/msprobe_adapter.py +30 -6
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +15 -19
- msprobe/visualization/graph/distributed_analyzer.py +395 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +100 -27
- msprobe/visualization/utils.py +24 -31
- mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
- msprobe/pytorch/functional/module_dump.py +0 -84
- msprobe/pytorch/module_processer.py +0 -150
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,16 +14,24 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import load_json
|
|
20
|
+
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
+
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
17
22
|
from msprobe.visualization.graph.graph import Graph
|
|
18
23
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
19
24
|
from msprobe.visualization.utils import save_json_file, GraphConst
|
|
20
|
-
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
-
from msprobe.core.common.file_utils import load_json
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class GraphBuilder:
|
|
28
|
+
backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
|
|
29
|
+
forward_pattern = re.compile(r"(\.forward\.)(\d+)$")
|
|
30
|
+
# 匹配以大写字母开头,后接任意字母,并以Template(结尾
|
|
31
|
+
template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
|
|
32
|
+
|
|
25
33
|
@staticmethod
|
|
26
|
-
def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
|
|
34
|
+
def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
|
|
27
35
|
"""
|
|
28
36
|
GraphBuilder的对外提供的构图方法
|
|
29
37
|
Args:
|
|
@@ -31,11 +39,14 @@ class GraphBuilder:
|
|
|
31
39
|
data_path: dump.json路径
|
|
32
40
|
stack_path: stack.json路径
|
|
33
41
|
model_name: 模型名字,依赖外部输入
|
|
42
|
+
complete_stack: 完整的堆栈信息
|
|
34
43
|
Returns: Graph,代表图的数据结构
|
|
35
44
|
"""
|
|
36
45
|
construct_dict = load_json(construct_path)
|
|
37
46
|
dump_dict = load_json(data_path)
|
|
38
47
|
stack_dict = load_json(stack_path)
|
|
48
|
+
if not complete_stack:
|
|
49
|
+
GraphBuilder._simplify_stack(stack_dict)
|
|
39
50
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
40
51
|
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
|
|
41
52
|
GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
|
|
@@ -61,20 +72,59 @@ class GraphBuilder:
|
|
|
61
72
|
result[GraphConst.MICRO_STEPS] = config.micro_steps
|
|
62
73
|
if config.task:
|
|
63
74
|
result[GraphConst.JSON_TASK_KEY] = config.task
|
|
75
|
+
result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
|
|
64
76
|
save_json_file(filename, result)
|
|
65
77
|
|
|
78
|
+
@staticmethod
|
|
79
|
+
def _simplify_stack(stack_dict):
|
|
80
|
+
"""
|
|
81
|
+
精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
|
|
82
|
+
|
|
83
|
+
例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
|
|
84
|
+
保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
|
|
85
|
+
|
|
86
|
+
例如Api Tensor.__iadd__.4.forward,堆栈为:
|
|
87
|
+
"File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
|
|
88
|
+
"File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
|
|
89
|
+
匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
|
|
90
|
+
"""
|
|
91
|
+
module_pattern = re.compile(op_patterns[0])
|
|
92
|
+
for dump_name, stack_list in stack_dict.items():
|
|
93
|
+
if not isinstance(stack_list, list):
|
|
94
|
+
continue
|
|
95
|
+
if module_pattern.match(dump_name):
|
|
96
|
+
parts = dump_name.split(Const.SEP)
|
|
97
|
+
if len(parts) < abs(Const.LAYER_NAME_INDEX):
|
|
98
|
+
continue
|
|
99
|
+
module_name = parts[Const.LAYER_NAME_INDEX]
|
|
100
|
+
for stack in stack_list:
|
|
101
|
+
if re.search(module_name + r'\(', stack):
|
|
102
|
+
stack_list = [stack]
|
|
103
|
+
break
|
|
104
|
+
else:
|
|
105
|
+
for index, stack in enumerate(stack_list):
|
|
106
|
+
if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
|
|
107
|
+
stack_list = [stack_list[index + 1]]
|
|
108
|
+
break
|
|
109
|
+
stack_dict[dump_name] = stack_list
|
|
110
|
+
|
|
66
111
|
@staticmethod
|
|
67
112
|
def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
|
|
68
113
|
"""
|
|
69
114
|
如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点
|
|
70
115
|
"""
|
|
71
116
|
# 匹配以.backward.后跟一个或多个数字结尾的模式
|
|
72
|
-
backward_pattern
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
117
|
+
if GraphBuilder.backward_pattern.search(subnode_id) and not upnode_id:
|
|
118
|
+
forward_upnode_id = construct_dict.get(GraphBuilder.backward_pattern.sub(r".forward.\2", subnode_id))
|
|
119
|
+
if forward_upnode_id:
|
|
120
|
+
new_upnode_id = GraphBuilder.forward_pattern.sub(r".backward.\2", forward_upnode_id)
|
|
121
|
+
if new_upnode_id in construct_dict:
|
|
122
|
+
return new_upnode_id
|
|
123
|
+
# 匹配以.backward结尾的节点
|
|
124
|
+
if subnode_id.endswith(Const.SEP + Const.BACKWARD) and not upnode_id:
|
|
125
|
+
forward_upnode_id = construct_dict.get(subnode_id.replace(Const.BACKWARD, Const.FORWARD))
|
|
76
126
|
if forward_upnode_id:
|
|
77
|
-
new_upnode_id =
|
|
127
|
+
new_upnode_id = forward_upnode_id.replace(Const.FORWARD, Const.BACKWARD)
|
|
78
128
|
if new_upnode_id in construct_dict:
|
|
79
129
|
return new_upnode_id
|
|
80
130
|
return upnode_id
|
|
@@ -104,11 +154,42 @@ class GraphBuilder:
|
|
|
104
154
|
input_data, output_data = get_input_output(node_data, node.id)
|
|
105
155
|
# 更新数据
|
|
106
156
|
node.set_input_output(input_data, output_data)
|
|
157
|
+
if GraphConst.BATCH_P2P in name:
|
|
158
|
+
GraphBuilder._extract_batch_p2p_info(node, node_data)
|
|
159
|
+
# 反向节点使用对应前向节点的堆栈信息
|
|
160
|
+
# 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward
|
|
161
|
+
if (not node_stack_info and
|
|
162
|
+
(GraphBuilder.backward_pattern.search(name) or name.endswith(f'{Const.SEP}{Const.BACKWARD}'))):
|
|
163
|
+
forward_node = graph.get_node(
|
|
164
|
+
# 同名模块全局唯一,无论调用几次堆栈信息都一致,直接使用编号0的同名模块堆栈信息,避免遗漏
|
|
165
|
+
GraphBuilder.backward_pattern.sub(f'{Const.SEP}{Const.FORWARD}{Const.SEP}0', name)) \
|
|
166
|
+
if GraphBuilder.backward_pattern.search(name) \
|
|
167
|
+
else graph.get_node(name.replace(Const.BACKWARD, Const.FORWARD))
|
|
168
|
+
node_stack_info = forward_node.stack_info if forward_node \
|
|
169
|
+
else ['This backward node cannot find the forward node and cannot retrieve stack information.']
|
|
107
170
|
node.stack_info = node_stack_info
|
|
108
171
|
# 添加节点
|
|
109
172
|
node.add_upnode(upnode)
|
|
110
173
|
return node
|
|
111
174
|
|
|
175
|
+
@staticmethod
|
|
176
|
+
def _is_valid_batch_p2p_output(param_list):
|
|
177
|
+
if not isinstance(param_list, list) or not param_list:
|
|
178
|
+
return False
|
|
179
|
+
if not isinstance(param_list[0], list) or not param_list[0]:
|
|
180
|
+
return False
|
|
181
|
+
return True
|
|
182
|
+
|
|
183
|
+
@staticmethod
|
|
184
|
+
def _extract_batch_p2p_info(node, node_data):
|
|
185
|
+
param_list = node_data.get(Const.OUTPUT, [])
|
|
186
|
+
# 数据格式:"output": [[{param1}, {param2}, ...]]
|
|
187
|
+
if GraphBuilder._is_valid_batch_p2p_output(param_list):
|
|
188
|
+
for param in param_list[0]:
|
|
189
|
+
info = {GraphConst.OP: param.get(GraphConst.OP), GraphConst.PEER: param.get(GraphConst.PEER),
|
|
190
|
+
GraphConst.GROUP_ID: param.get(GraphConst.GROUP_ID)}
|
|
191
|
+
node.batch_p2p_info.append(info)
|
|
192
|
+
|
|
112
193
|
@staticmethod
|
|
113
194
|
def _collect_apis_between_modules(graph):
|
|
114
195
|
"""
|
|
@@ -156,10 +237,12 @@ class GraphBuilder:
|
|
|
156
237
|
|
|
157
238
|
|
|
158
239
|
class GraphExportConfig:
|
|
159
|
-
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''
|
|
240
|
+
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
241
|
+
overflow_check=False):
|
|
160
242
|
self.graph_n = graph_n
|
|
161
243
|
self.graph_b = graph_b
|
|
162
244
|
self.tool_tip = tool_tip
|
|
163
245
|
self.node_colors = node_colors
|
|
164
246
|
self.micro_steps = micro_steps
|
|
165
247
|
self.task = task
|
|
248
|
+
self.overflow_check = overflow_check
|
|
@@ -18,11 +18,12 @@ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
|
|
|
18
18
|
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
19
|
from msprobe.visualization.utils import GraphConst
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
21
22
|
|
|
22
23
|
# 用于将节点名字解析成对应的NodeOp的规则
|
|
23
24
|
op_patterns = [
|
|
24
25
|
# NodeOp.module
|
|
25
|
-
r'^(Module.|Cell
|
|
26
|
+
r'^(Module.|Cell.|optimizer|clip_grad)',
|
|
26
27
|
# NodeOp.function_api
|
|
27
28
|
r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)'
|
|
28
29
|
]
|
|
@@ -50,12 +51,14 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
|
|
|
50
51
|
framework: 框架类型, pytorch或mindspore
|
|
51
52
|
is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆
|
|
52
53
|
"""
|
|
54
|
+
mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
|
|
55
|
+
|
|
53
56
|
if framework == Const.PT_FRAMEWORK:
|
|
54
57
|
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
55
|
-
return PTComparator().do_multi_process(dump_path_param, csv_path)
|
|
58
|
+
return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
|
|
56
59
|
else:
|
|
57
|
-
from msprobe.mindspore.compare.ms_compare import MSComparator
|
|
58
|
-
ms_comparator = MSComparator()
|
|
60
|
+
from msprobe.mindspore.compare.ms_compare import MSComparator, MappingConfig
|
|
61
|
+
ms_comparator = MSComparator(mode_config, MappingConfig())
|
|
59
62
|
ms_comparator.cross_frame = is_cross_frame
|
|
60
63
|
return ms_comparator.do_multi_process(dump_path_param, csv_path)
|
|
61
64
|
|
|
@@ -105,11 +108,25 @@ def compare_data(data_dict_list1, data_dict_list2):
|
|
|
105
108
|
return True
|
|
106
109
|
|
|
107
110
|
|
|
108
|
-
def
|
|
111
|
+
def compare_data_fuzzy(data_dict_list1, data_dict_list2):
|
|
112
|
+
"""
|
|
113
|
+
模糊匹配,仅校验参数shape是否一致
|
|
114
|
+
"""
|
|
115
|
+
for x, y in zip(data_dict_list1.values(), data_dict_list2.values()):
|
|
116
|
+
x_shape = x.get(Const.SHAPE)
|
|
117
|
+
y_shape = y.get(Const.SHAPE)
|
|
118
|
+
if x_shape != y_shape:
|
|
119
|
+
return False
|
|
120
|
+
return True
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def format_node_data(data_dict, node_id=None):
|
|
109
124
|
"""
|
|
110
|
-
|
|
125
|
+
删除节点数据中不需要展示的字段
|
|
111
126
|
"""
|
|
112
127
|
del_list = ['requires_grad', 'full_op_name']
|
|
128
|
+
if node_id and GraphConst.BATCH_P2P in node_id:
|
|
129
|
+
del_list.extend(['op', 'peer', 'tag', 'group_id'])
|
|
113
130
|
for _, value in data_dict.items():
|
|
114
131
|
if not isinstance(value, dict):
|
|
115
132
|
continue
|
|
@@ -179,6 +196,13 @@ def _format_data(data_dict):
|
|
|
179
196
|
"""
|
|
180
197
|
pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
|
|
181
198
|
all_null = False
|
|
199
|
+
|
|
200
|
+
keys_to_keep = ['type', 'group_ranks', 'group_id', 'data_name']
|
|
201
|
+
if data_dict.get('type') == 'torch.ProcessGroup':
|
|
202
|
+
keys_to_remove = [key for key in data_dict if key not in keys_to_keep]
|
|
203
|
+
for key in keys_to_remove:
|
|
204
|
+
del data_dict[key]
|
|
205
|
+
|
|
182
206
|
for key, value in data_dict.items():
|
|
183
207
|
if isinstance(value, str):
|
|
184
208
|
# 将单引号删掉,None换成null避免前端解析错误
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import re
|
|
16
17
|
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
17
18
|
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
|
|
18
19
|
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
@@ -22,18 +23,23 @@ from msprobe.core.common.const import Const
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class GraphComparator:
|
|
25
|
-
def __init__(self, graphs, dump_path_param,
|
|
26
|
+
def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
|
|
26
27
|
self.graph_n = graphs[0]
|
|
27
28
|
self.graph_b = graphs[1]
|
|
28
|
-
self._parse_param(dump_path_param, output_path)
|
|
29
|
-
self.framework = framework
|
|
29
|
+
self._parse_param(dump_path_param, args.output_path)
|
|
30
|
+
self.framework = args.framework
|
|
30
31
|
self.mapping_dict = mapping_dict
|
|
32
|
+
self.fuzzy_match = args.fuzzy_match
|
|
33
|
+
self.pattern = re.compile(r'\.\d+\.')
|
|
31
34
|
|
|
32
35
|
def compare(self):
|
|
33
36
|
"""
|
|
34
37
|
比较函数,初始化结束后单独调用。比较结果写入graph_n
|
|
35
38
|
"""
|
|
36
|
-
self.
|
|
39
|
+
if self.fuzzy_match:
|
|
40
|
+
self._compare_nodes_fuzzy(self.graph_n.root)
|
|
41
|
+
else:
|
|
42
|
+
self._compare_nodes(self.graph_n.root)
|
|
37
43
|
self._postcompare()
|
|
38
44
|
|
|
39
45
|
def add_compare_result_to_node(self, node, compare_result_list):
|
|
@@ -60,8 +66,6 @@ class GraphComparator:
|
|
|
60
66
|
self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
|
|
61
67
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
62
68
|
node.data.update(other_dict)
|
|
63
|
-
if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
|
|
64
|
-
node.get_suggestions()
|
|
65
69
|
|
|
66
70
|
def _parse_param(self, dump_path_param, output_path):
|
|
67
71
|
self.dump_path_param = dump_path_param
|
|
@@ -82,8 +86,6 @@ class GraphComparator:
|
|
|
82
86
|
for node in self.ma.compare_nodes:
|
|
83
87
|
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
84
88
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
85
|
-
if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
|
|
86
|
-
node.get_suggestions()
|
|
87
89
|
|
|
88
90
|
def _handle_api_collection_index(self):
|
|
89
91
|
"""
|
|
@@ -120,11 +122,59 @@ class GraphComparator:
|
|
|
120
122
|
node_n.add_link(node_b, ancestors)
|
|
121
123
|
if node_b:
|
|
122
124
|
# 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
|
|
123
|
-
|
|
124
|
-
[self.data_n_dict, self.data_b_dict],
|
|
125
|
-
self.stack_json_data, self.ma.compare_mode)
|
|
126
|
-
if compare_result_list:
|
|
127
|
-
self.ma.add_csv_data(compare_result_list)
|
|
128
|
-
self.add_compare_result_to_node(node_n, compare_result_list)
|
|
125
|
+
self._get_and_add_result(node_n, node_b)
|
|
129
126
|
for subnode in node_n.subnodes:
|
|
130
127
|
self._compare_nodes(subnode)
|
|
128
|
+
|
|
129
|
+
def _compare_nodes_fuzzy(self, node_n):
|
|
130
|
+
if node_n.op != NodeOp.function_api:
|
|
131
|
+
# 模块经过模糊匹配
|
|
132
|
+
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
|
|
133
|
+
if node_b:
|
|
134
|
+
self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
|
|
135
|
+
# 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
|
|
136
|
+
recount_result_n = self._recount_api_node(node_n)
|
|
137
|
+
recount_result_b = self._recount_api_node(node_b)
|
|
138
|
+
for recount_node_id, node_id_n in recount_result_n.items():
|
|
139
|
+
api_node_n = self.graph_n.node_map.get(node_id_n)
|
|
140
|
+
if not api_node_n:
|
|
141
|
+
continue
|
|
142
|
+
api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
|
|
143
|
+
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
|
|
144
|
+
if api_node_b:
|
|
145
|
+
self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
|
|
146
|
+
for sub_node in node_n.subnodes:
|
|
147
|
+
self._compare_nodes_fuzzy(sub_node)
|
|
148
|
+
|
|
149
|
+
def _get_and_add_result(self, node_n, node_b):
|
|
150
|
+
compare_result_list = compare_node([node_n.id, node_b.id],
|
|
151
|
+
[self.data_n_dict, self.data_b_dict],
|
|
152
|
+
self.stack_json_data, self.ma.compare_mode)
|
|
153
|
+
if compare_result_list:
|
|
154
|
+
self.ma.add_csv_data(compare_result_list)
|
|
155
|
+
self.add_compare_result_to_node(node_n, compare_result_list)
|
|
156
|
+
|
|
157
|
+
def _recount_api_node(self, node):
|
|
158
|
+
"""
|
|
159
|
+
两个匹配上的模块, 忽略各自模块下所有api的dump调用次数, 并赋予模块中的调用顺序
|
|
160
|
+
Return:
|
|
161
|
+
{赋予模块中的调用顺序的node_id: 原始node_id}
|
|
162
|
+
"""
|
|
163
|
+
recount_result = {}
|
|
164
|
+
node_count = {}
|
|
165
|
+
for sub_node in node.subnodes:
|
|
166
|
+
if sub_node.op == NodeOp.function_api:
|
|
167
|
+
# 忽略dump调用次数
|
|
168
|
+
count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
|
|
169
|
+
node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
|
|
170
|
+
# 赋予模块中的调用顺序
|
|
171
|
+
recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
|
|
172
|
+
recount_result[recount_node_id] = sub_node.id
|
|
173
|
+
return recount_result
|
|
174
|
+
|
|
175
|
+
def _process_matched_nodes(self, node_n, node_b, ancestors_n, ancestors_b):
|
|
176
|
+
ancestors_n.append(node_n.id)
|
|
177
|
+
ancestors_b.append(node_b.id)
|
|
178
|
+
node_n.matched_node_link = ancestors_b
|
|
179
|
+
node_b.matched_node_link = ancestors_n
|
|
180
|
+
self._get_and_add_result(node_n, node_b)
|
|
@@ -83,27 +83,13 @@ class ModeAdapter:
|
|
|
83
83
|
continue
|
|
84
84
|
compare_data = compare_data_dict.get(key)
|
|
85
85
|
if compare_data:
|
|
86
|
-
dtype = data_info.get(Const.DTYPE)
|
|
87
86
|
# 对应比对结果csv的列
|
|
88
87
|
key_list = GraphConst.SUMMARY_INDEX_LIST
|
|
89
88
|
headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
|
|
90
89
|
id_list = [headers.index(x) for x in key_list]
|
|
91
90
|
ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
|
|
92
|
-
for
|
|
93
|
-
value = data_info.get(GraphConst.VALUE_INDEX_LIST[index])
|
|
94
|
-
value_diff = data_info.get(key_list[index])
|
|
91
|
+
for item in key_list[4:]:
|
|
95
92
|
relative_err = str2float(data_info.get(item))
|
|
96
|
-
if isinstance(value, float) and isinstance(value_diff, float) \
|
|
97
|
-
and dtype in GraphConst.SMALL_VALUES.keys():
|
|
98
|
-
small_value = GraphConst.SMALL_VALUES.get(dtype)
|
|
99
|
-
# 小值域
|
|
100
|
-
if abs(value) <= small_value:
|
|
101
|
-
data_info[item] = ToolTip.SMALL_VALUE_TIP.format(data_info.get(item),
|
|
102
|
-
GraphConst.VALUE_INDEX_LIST[index],
|
|
103
|
-
small_value)
|
|
104
|
-
relative_err = GraphConst.MIN_INDEX_KEY \
|
|
105
|
-
if abs(value_diff) <= GraphConst.SMALL_VALUES_ABS_ERROR.get(dtype) \
|
|
106
|
-
else GraphConst.MAX_INDEX_KEY
|
|
107
93
|
max_relative_err = max(max_relative_err, relative_err)
|
|
108
94
|
node_data[key] = data_info
|
|
109
95
|
max_relative_err = 1 if max_relative_err > 1 else max_relative_err
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
from msprobe.core.overflow_check.level import OverflowLevel
|
|
16
16
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
17
|
-
from msprobe.visualization.utils import
|
|
18
|
-
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data
|
|
17
|
+
from msprobe.visualization.utils import GraphConst
|
|
18
|
+
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class BaseNode:
|
|
@@ -33,6 +33,8 @@ class BaseNode:
|
|
|
33
33
|
self.stack_info = []
|
|
34
34
|
self.micro_step_id = None
|
|
35
35
|
self.overflow_level = None
|
|
36
|
+
self.matched_distributed = {}
|
|
37
|
+
self.batch_p2p_info = []
|
|
36
38
|
|
|
37
39
|
def __str__(self):
|
|
38
40
|
info = f'id:\t{self.id}'
|
|
@@ -48,16 +50,12 @@ class BaseNode:
|
|
|
48
50
|
return False
|
|
49
51
|
return True
|
|
50
52
|
|
|
51
|
-
def
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
|
|
58
|
-
elif self.op == NodeOp.function_api:
|
|
59
|
-
self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
|
|
60
|
-
self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL
|
|
53
|
+
def fuzzy_eq(self, other):
|
|
54
|
+
if not compare_data_fuzzy(self.input_data, other.input_data):
|
|
55
|
+
return False
|
|
56
|
+
if not compare_data_fuzzy(self.output_data, other.output_data):
|
|
57
|
+
return False
|
|
58
|
+
return True
|
|
61
59
|
|
|
62
60
|
def set_input_output(self, input_data, output_data):
|
|
63
61
|
self.input_data = input_data
|
|
@@ -67,6 +65,7 @@ class BaseNode:
|
|
|
67
65
|
if not level or not isinstance(level, OverflowLevel):
|
|
68
66
|
return
|
|
69
67
|
self.overflow_level = level
|
|
68
|
+
self.data[GraphConst.OVERFLOW_LEVEL] = self.overflow_level.value
|
|
70
69
|
|
|
71
70
|
def add_upnode(self, node):
|
|
72
71
|
"""
|
|
@@ -94,8 +93,8 @@ class BaseNode:
|
|
|
94
93
|
result = {
|
|
95
94
|
'id': self.id,
|
|
96
95
|
'node_type': self.op.value,
|
|
97
|
-
'output_data': format_node_data(self.output_data),
|
|
98
|
-
'input_data': format_node_data(self.input_data),
|
|
96
|
+
'output_data': format_node_data(self.output_data, self.id),
|
|
97
|
+
'input_data': format_node_data(self.input_data, self.id),
|
|
99
98
|
'upnode': self.upnode.id if self.upnode else 'None',
|
|
100
99
|
'subnodes': [node.id for node in self.subnodes],
|
|
101
100
|
'matched_node_link': self.matched_node_link,
|
|
@@ -104,12 +103,9 @@ class BaseNode:
|
|
|
104
103
|
}
|
|
105
104
|
if self.micro_step_id is not None:
|
|
106
105
|
result['micro_step_id'] = self.micro_step_id
|
|
107
|
-
# 是否存在overflow,并保存结果
|
|
108
|
-
if self.overflow_level and isinstance(self.overflow_level, OverflowLevel):
|
|
109
|
-
if self.data is None:
|
|
110
|
-
self.data = dict()
|
|
111
|
-
self.data['overflow_level'] = self.overflow_level.value
|
|
112
106
|
result['data'] = self.data
|
|
107
|
+
if self.matched_distributed:
|
|
108
|
+
result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
|
|
113
109
|
return result
|
|
114
110
|
|
|
115
111
|
def get_ancestors(self):
|