mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,16 +14,23 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import re
|
|
17
|
+
|
|
18
|
+
from msprobe.core.common.const import Const
|
|
19
|
+
from msprobe.core.common.file_utils import load_json
|
|
20
|
+
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
+
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
17
22
|
from msprobe.visualization.graph.graph import Graph
|
|
18
23
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
19
24
|
from msprobe.visualization.utils import save_json_file, GraphConst
|
|
20
|
-
from msprobe.visualization.builder.msprobe_adapter import get_input_output
|
|
21
|
-
from msprobe.core.common.file_utils import load_json
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class GraphBuilder:
|
|
28
|
+
backward_pattern = re.compile(r"(\.backward\.)(\d+)$")
|
|
29
|
+
# 匹配以大写字母开头,后接任意字母,并以Template(结尾
|
|
30
|
+
template_pattern = re.compile(r'\b[A-Z][a-zA-Z]*Template\(')
|
|
31
|
+
|
|
25
32
|
@staticmethod
|
|
26
|
-
def build(construct_path, data_path, stack_path, model_name='DefaultModel'):
|
|
33
|
+
def build(construct_path, data_path, stack_path, model_name='DefaultModel', complete_stack=False):
|
|
27
34
|
"""
|
|
28
35
|
GraphBuilder的对外提供的构图方法
|
|
29
36
|
Args:
|
|
@@ -31,11 +38,14 @@ class GraphBuilder:
|
|
|
31
38
|
data_path: dump.json路径
|
|
32
39
|
stack_path: stack.json路径
|
|
33
40
|
model_name: 模型名字,依赖外部输入
|
|
41
|
+
complete_stack: 完整的堆栈信息
|
|
34
42
|
Returns: Graph,代表图的数据结构
|
|
35
43
|
"""
|
|
36
44
|
construct_dict = load_json(construct_path)
|
|
37
45
|
dump_dict = load_json(data_path)
|
|
38
46
|
stack_dict = load_json(stack_path)
|
|
47
|
+
if not complete_stack:
|
|
48
|
+
GraphBuilder._simplify_stack(stack_dict)
|
|
39
49
|
data_dict = dump_dict.get(GraphConst.DATA_KEY, {})
|
|
40
50
|
graph = Graph(model_name, data_path=dump_dict.get('dump_data_dir', ''), dump_data=data_dict)
|
|
41
51
|
GraphBuilder._init_nodes(graph, construct_dict, data_dict, stack_dict)
|
|
@@ -61,8 +71,42 @@ class GraphBuilder:
|
|
|
61
71
|
result[GraphConst.MICRO_STEPS] = config.micro_steps
|
|
62
72
|
if config.task:
|
|
63
73
|
result[GraphConst.JSON_TASK_KEY] = config.task
|
|
74
|
+
result[GraphConst.OVERFLOW_CHECK] = config.overflow_check
|
|
64
75
|
save_json_file(filename, result)
|
|
65
76
|
|
|
77
|
+
@staticmethod
|
|
78
|
+
def _simplify_stack(stack_dict):
|
|
79
|
+
"""
|
|
80
|
+
精简堆栈内容,模块级保留包含"模块名("的堆栈,api级保留"xxxTemplate("的下一行堆栈
|
|
81
|
+
|
|
82
|
+
例如模块 Module.layer3.0.bn2.BatchNorm2d.forward.0,模块名为bn2,匹配"bn2(",
|
|
83
|
+
保留堆栈"File /home/models/resnet.py, line 97, in forward, \n out = self.bn2(out)"
|
|
84
|
+
|
|
85
|
+
例如Api Tensor.__iadd__.4.forward,堆栈为:
|
|
86
|
+
"File /home/wrap_tensor.py, line 61, return TensorOPTemplate(op_name, hook)(*args, **kwargs)",
|
|
87
|
+
"File /home/torchvision/models/resnet.py, line 102, in forward, \n out += identity",
|
|
88
|
+
匹配到第一行的"TensorOPTemplate(",保留下一行堆栈
|
|
89
|
+
"""
|
|
90
|
+
module_pattern = re.compile(op_patterns[0])
|
|
91
|
+
for dump_name, stack_list in stack_dict.items():
|
|
92
|
+
if not isinstance(stack_list, list):
|
|
93
|
+
continue
|
|
94
|
+
if module_pattern.match(dump_name):
|
|
95
|
+
parts = dump_name.split(Const.SEP)
|
|
96
|
+
if len(parts) < abs(Const.LAYER_NAME_INDEX):
|
|
97
|
+
continue
|
|
98
|
+
module_name = parts[Const.LAYER_NAME_INDEX]
|
|
99
|
+
for stack in stack_list:
|
|
100
|
+
if re.search(module_name + r'\(', stack):
|
|
101
|
+
stack_list = [stack]
|
|
102
|
+
break
|
|
103
|
+
else:
|
|
104
|
+
for index, stack in enumerate(stack_list):
|
|
105
|
+
if GraphBuilder.template_pattern.search(stack) and index < len(stack_list) - 1:
|
|
106
|
+
stack_list = [stack_list[index + 1]]
|
|
107
|
+
break
|
|
108
|
+
stack_dict[dump_name] = stack_list
|
|
109
|
+
|
|
66
110
|
@staticmethod
|
|
67
111
|
def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
|
|
68
112
|
"""
|
|
@@ -104,6 +148,17 @@ class GraphBuilder:
|
|
|
104
148
|
input_data, output_data = get_input_output(node_data, node.id)
|
|
105
149
|
# 更新数据
|
|
106
150
|
node.set_input_output(input_data, output_data)
|
|
151
|
+
# 反向节点使用对应前向节点的堆栈信息
|
|
152
|
+
# 模块命名举例:Module.module.module.GPTModel.backward.0; API命名举例:Tensor.permute.1.backward
|
|
153
|
+
if (not node_stack_info and
|
|
154
|
+
(GraphBuilder.backward_pattern.search(name) or name.endswith(f'{Const.SEP}{Const.BACKWARD}'))):
|
|
155
|
+
forward_node = graph.get_node(
|
|
156
|
+
# 同名模块全局唯一,无论调用几次堆栈信息都一致,直接使用编号0的同名模块堆栈信息,避免遗漏
|
|
157
|
+
GraphBuilder.backward_pattern.sub(f'{Const.SEP}{Const.FORWARD}{Const.SEP}0', name)) \
|
|
158
|
+
if GraphBuilder.backward_pattern.search(name) \
|
|
159
|
+
else graph.get_node(name.replace(Const.BACKWARD, Const.FORWARD))
|
|
160
|
+
node_stack_info = forward_node.stack_info if forward_node \
|
|
161
|
+
else ['This backward node cannot find the forward node and cannot retrieve stack information.']
|
|
107
162
|
node.stack_info = node_stack_info
|
|
108
163
|
# 添加节点
|
|
109
164
|
node.add_upnode(upnode)
|
|
@@ -156,10 +211,12 @@ class GraphBuilder:
|
|
|
156
211
|
|
|
157
212
|
|
|
158
213
|
class GraphExportConfig:
|
|
159
|
-
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task=''
|
|
214
|
+
def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='',
|
|
215
|
+
overflow_check=False):
|
|
160
216
|
self.graph_n = graph_n
|
|
161
217
|
self.graph_b = graph_b
|
|
162
218
|
self.tool_tip = tool_tip
|
|
163
219
|
self.node_colors = node_colors
|
|
164
220
|
self.micro_steps = micro_steps
|
|
165
221
|
self.task = task
|
|
222
|
+
self.overflow_check = overflow_check
|
|
@@ -18,6 +18,7 @@ from msprobe.core.compare.acc_compare import read_op, merge_tensor, get_accuracy
|
|
|
18
18
|
from msprobe.core.common.utils import set_dump_path, get_dump_mode
|
|
19
19
|
from msprobe.visualization.utils import GraphConst
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
|
+
from msprobe.core.compare.acc_compare import ModeConfig
|
|
21
22
|
|
|
22
23
|
# 用于将节点名字解析成对应的NodeOp的规则
|
|
23
24
|
op_patterns = [
|
|
@@ -50,12 +51,14 @@ def run_real_data(dump_path_param, csv_path, framework, is_cross_frame=False):
|
|
|
50
51
|
framework: 框架类型, pytorch或mindspore
|
|
51
52
|
is_cross_frame: 是否进行跨框架比对,仅支持mindspore比pytorch, 其中pytorch为标杆
|
|
52
53
|
"""
|
|
54
|
+
mode_config = ModeConfig(stack_mode=False, auto_analyze=True, fuzzy_match=False, dump_mode=Const.ALL)
|
|
55
|
+
|
|
53
56
|
if framework == Const.PT_FRAMEWORK:
|
|
54
57
|
from msprobe.pytorch.compare.pt_compare import PTComparator
|
|
55
|
-
return PTComparator().do_multi_process(dump_path_param, csv_path)
|
|
58
|
+
return PTComparator(mode_config).do_multi_process(dump_path_param, csv_path)
|
|
56
59
|
else:
|
|
57
60
|
from msprobe.mindspore.compare.ms_compare import MSComparator
|
|
58
|
-
ms_comparator = MSComparator()
|
|
61
|
+
ms_comparator = MSComparator(mode_config)
|
|
59
62
|
ms_comparator.cross_frame = is_cross_frame
|
|
60
63
|
return ms_comparator.do_multi_process(dump_path_param, csv_path)
|
|
61
64
|
|
|
@@ -105,6 +108,18 @@ def compare_data(data_dict_list1, data_dict_list2):
|
|
|
105
108
|
return True
|
|
106
109
|
|
|
107
110
|
|
|
111
|
+
def compare_data_fuzzy(data_dict_list1, data_dict_list2):
|
|
112
|
+
"""
|
|
113
|
+
模糊匹配,仅校验参数shape是否一致
|
|
114
|
+
"""
|
|
115
|
+
for x, y in zip(data_dict_list1.values(), data_dict_list2.values()):
|
|
116
|
+
x_shape = x.get(Const.SHAPE)
|
|
117
|
+
y_shape = y.get(Const.SHAPE)
|
|
118
|
+
if x_shape != y_shape:
|
|
119
|
+
return False
|
|
120
|
+
return True
|
|
121
|
+
|
|
122
|
+
|
|
108
123
|
def format_node_data(data_dict):
|
|
109
124
|
"""
|
|
110
125
|
批量进行节点数据的输出
|
|
@@ -179,6 +194,13 @@ def _format_data(data_dict):
|
|
|
179
194
|
"""
|
|
180
195
|
pattern = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$'
|
|
181
196
|
all_null = False
|
|
197
|
+
|
|
198
|
+
keys_to_keep = ['type', 'group_ranks', 'group_id', 'data_name']
|
|
199
|
+
if data_dict.get('type') == 'torch.ProcessGroup':
|
|
200
|
+
keys_to_remove = [key for key in data_dict if key not in keys_to_keep]
|
|
201
|
+
for key in keys_to_remove:
|
|
202
|
+
del data_dict[key]
|
|
203
|
+
|
|
182
204
|
for key, value in data_dict.items():
|
|
183
205
|
if isinstance(value, str):
|
|
184
206
|
# 将单引号删掉,None换成null避免前端解析错误
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import re
|
|
16
17
|
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
17
18
|
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
|
|
18
19
|
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
@@ -22,18 +23,23 @@ from msprobe.core.common.const import Const
|
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
class GraphComparator:
|
|
25
|
-
def __init__(self, graphs, dump_path_param,
|
|
26
|
+
def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
|
|
26
27
|
self.graph_n = graphs[0]
|
|
27
28
|
self.graph_b = graphs[1]
|
|
28
|
-
self._parse_param(dump_path_param, output_path)
|
|
29
|
-
self.framework = framework
|
|
29
|
+
self._parse_param(dump_path_param, args.output_path)
|
|
30
|
+
self.framework = args.framework
|
|
30
31
|
self.mapping_dict = mapping_dict
|
|
32
|
+
self.fuzzy_match = args.fuzzy_match
|
|
33
|
+
self.pattern = re.compile(r'\.\d+\.')
|
|
31
34
|
|
|
32
35
|
def compare(self):
|
|
33
36
|
"""
|
|
34
37
|
比较函数,初始化结束后单独调用。比较结果写入graph_n
|
|
35
38
|
"""
|
|
36
|
-
self.
|
|
39
|
+
if self.fuzzy_match:
|
|
40
|
+
self._compare_nodes_fuzzy(self.graph_n.root)
|
|
41
|
+
else:
|
|
42
|
+
self._compare_nodes(self.graph_n.root)
|
|
37
43
|
self._postcompare()
|
|
38
44
|
|
|
39
45
|
def add_compare_result_to_node(self, node, compare_result_list):
|
|
@@ -60,8 +66,6 @@ class GraphComparator:
|
|
|
60
66
|
self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
|
|
61
67
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
62
68
|
node.data.update(other_dict)
|
|
63
|
-
if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
|
|
64
|
-
node.get_suggestions()
|
|
65
69
|
|
|
66
70
|
def _parse_param(self, dump_path_param, output_path):
|
|
67
71
|
self.dump_path_param = dump_path_param
|
|
@@ -82,8 +86,6 @@ class GraphComparator:
|
|
|
82
86
|
for node in self.ma.compare_nodes:
|
|
83
87
|
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
84
88
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
85
|
-
if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
|
|
86
|
-
node.get_suggestions()
|
|
87
89
|
|
|
88
90
|
def _handle_api_collection_index(self):
|
|
89
91
|
"""
|
|
@@ -120,11 +122,59 @@ class GraphComparator:
|
|
|
120
122
|
node_n.add_link(node_b, ancestors)
|
|
121
123
|
if node_b:
|
|
122
124
|
# 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口
|
|
123
|
-
|
|
124
|
-
[self.data_n_dict, self.data_b_dict],
|
|
125
|
-
self.stack_json_data, self.ma.compare_mode)
|
|
126
|
-
if compare_result_list:
|
|
127
|
-
self.ma.add_csv_data(compare_result_list)
|
|
128
|
-
self.add_compare_result_to_node(node_n, compare_result_list)
|
|
125
|
+
self._get_and_add_result(node_n, node_b)
|
|
129
126
|
for subnode in node_n.subnodes:
|
|
130
127
|
self._compare_nodes(subnode)
|
|
128
|
+
|
|
129
|
+
def _compare_nodes_fuzzy(self, node_n):
|
|
130
|
+
if node_n.op != NodeOp.function_api:
|
|
131
|
+
# 模块经过模糊匹配
|
|
132
|
+
node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(node_n, self.graph_b.node_map.get(node_n.id))
|
|
133
|
+
if node_b:
|
|
134
|
+
self._process_matched_nodes(node_n, node_b, ancestors_n, ancestors_b)
|
|
135
|
+
# 匹配上的两个模块中的所有api, 忽略dump调用次数,按照名称一致+模块中的调用顺序进行匹配
|
|
136
|
+
recount_result_n = self._recount_api_node(node_n)
|
|
137
|
+
recount_result_b = self._recount_api_node(node_b)
|
|
138
|
+
for recount_node_id, node_id_n in recount_result_n.items():
|
|
139
|
+
api_node_n = self.graph_n.node_map.get(node_id_n)
|
|
140
|
+
if not api_node_n:
|
|
141
|
+
continue
|
|
142
|
+
api_node_b, ancestors_n, ancestors_b = Graph.fuzzy_match(
|
|
143
|
+
api_node_n, self.graph_b.node_map.get(recount_result_b.get(recount_node_id)))
|
|
144
|
+
if api_node_b:
|
|
145
|
+
self._process_matched_nodes(api_node_n, api_node_b, ancestors_n, ancestors_b)
|
|
146
|
+
for sub_node in node_n.subnodes:
|
|
147
|
+
self._compare_nodes_fuzzy(sub_node)
|
|
148
|
+
|
|
149
|
+
def _get_and_add_result(self, node_n, node_b):
|
|
150
|
+
compare_result_list = compare_node([node_n.id, node_b.id],
|
|
151
|
+
[self.data_n_dict, self.data_b_dict],
|
|
152
|
+
self.stack_json_data, self.ma.compare_mode)
|
|
153
|
+
if compare_result_list:
|
|
154
|
+
self.ma.add_csv_data(compare_result_list)
|
|
155
|
+
self.add_compare_result_to_node(node_n, compare_result_list)
|
|
156
|
+
|
|
157
|
+
def _recount_api_node(self, node):
|
|
158
|
+
"""
|
|
159
|
+
两个匹配上的模块, 忽略各自模块下所有api的dump调用次数, 并赋予模块中的调用顺序
|
|
160
|
+
Return:
|
|
161
|
+
{赋予模块中的调用顺序的node_id: 原始node_id}
|
|
162
|
+
"""
|
|
163
|
+
recount_result = {}
|
|
164
|
+
node_count = {}
|
|
165
|
+
for sub_node in node.subnodes:
|
|
166
|
+
if sub_node.op == NodeOp.function_api:
|
|
167
|
+
# 忽略dump调用次数
|
|
168
|
+
count_removed_id = self.pattern.sub(Const.SEP, sub_node.id)
|
|
169
|
+
node_count[count_removed_id] = node_count.get(count_removed_id, 0) + 1
|
|
170
|
+
# 赋予模块中的调用顺序
|
|
171
|
+
recount_node_id = count_removed_id + str(node_count.get(count_removed_id))
|
|
172
|
+
recount_result[recount_node_id] = sub_node.id
|
|
173
|
+
return recount_result
|
|
174
|
+
|
|
175
|
+
def _process_matched_nodes(self, node_n, node_b, ancestors_n, ancestors_b):
|
|
176
|
+
ancestors_n.append(node_n.id)
|
|
177
|
+
ancestors_b.append(node_b.id)
|
|
178
|
+
node_n.matched_node_link = ancestors_b
|
|
179
|
+
node_b.matched_node_link = ancestors_n
|
|
180
|
+
self._get_and_add_result(node_n, node_b)
|
|
@@ -83,27 +83,13 @@ class ModeAdapter:
|
|
|
83
83
|
continue
|
|
84
84
|
compare_data = compare_data_dict.get(key)
|
|
85
85
|
if compare_data:
|
|
86
|
-
dtype = data_info.get(Const.DTYPE)
|
|
87
86
|
# 对应比对结果csv的列
|
|
88
87
|
key_list = GraphConst.SUMMARY_INDEX_LIST
|
|
89
88
|
headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
|
|
90
89
|
id_list = [headers.index(x) for x in key_list]
|
|
91
90
|
ModeAdapter._match_data(data_info, compare_data, key_list, id_list)
|
|
92
|
-
for
|
|
93
|
-
value = data_info.get(GraphConst.VALUE_INDEX_LIST[index])
|
|
94
|
-
value_diff = data_info.get(key_list[index])
|
|
91
|
+
for item in key_list[4:]:
|
|
95
92
|
relative_err = str2float(data_info.get(item))
|
|
96
|
-
if isinstance(value, float) and isinstance(value_diff, float) \
|
|
97
|
-
and dtype in GraphConst.SMALL_VALUES.keys():
|
|
98
|
-
small_value = GraphConst.SMALL_VALUES.get(dtype)
|
|
99
|
-
# 小值域
|
|
100
|
-
if abs(value) <= small_value:
|
|
101
|
-
data_info[item] = ToolTip.SMALL_VALUE_TIP.format(data_info.get(item),
|
|
102
|
-
GraphConst.VALUE_INDEX_LIST[index],
|
|
103
|
-
small_value)
|
|
104
|
-
relative_err = GraphConst.MIN_INDEX_KEY \
|
|
105
|
-
if abs(value_diff) <= GraphConst.SMALL_VALUES_ABS_ERROR.get(dtype) \
|
|
106
|
-
else GraphConst.MAX_INDEX_KEY
|
|
107
93
|
max_relative_err = max(max_relative_err, relative_err)
|
|
108
94
|
node_data[key] = data_info
|
|
109
95
|
max_relative_err = 1 if max_relative_err > 1 else max_relative_err
|
|
@@ -14,8 +14,8 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
from msprobe.core.overflow_check.level import OverflowLevel
|
|
16
16
|
from msprobe.visualization.graph.node_op import NodeOp
|
|
17
|
-
from msprobe.visualization.utils import
|
|
18
|
-
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data
|
|
17
|
+
from msprobe.visualization.utils import GraphConst
|
|
18
|
+
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class BaseNode:
|
|
@@ -33,6 +33,7 @@ class BaseNode:
|
|
|
33
33
|
self.stack_info = []
|
|
34
34
|
self.micro_step_id = None
|
|
35
35
|
self.overflow_level = None
|
|
36
|
+
self.matched_distributed = {}
|
|
36
37
|
|
|
37
38
|
def __str__(self):
|
|
38
39
|
info = f'id:\t{self.id}'
|
|
@@ -48,16 +49,12 @@ class BaseNode:
|
|
|
48
49
|
return False
|
|
49
50
|
return True
|
|
50
51
|
|
|
51
|
-
def
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
|
|
58
|
-
elif self.op == NodeOp.function_api:
|
|
59
|
-
self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
|
|
60
|
-
self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL
|
|
52
|
+
def fuzzy_eq(self, other):
|
|
53
|
+
if not compare_data_fuzzy(self.input_data, other.input_data):
|
|
54
|
+
return False
|
|
55
|
+
if not compare_data_fuzzy(self.output_data, other.output_data):
|
|
56
|
+
return False
|
|
57
|
+
return True
|
|
61
58
|
|
|
62
59
|
def set_input_output(self, input_data, output_data):
|
|
63
60
|
self.input_data = input_data
|
|
@@ -67,6 +64,7 @@ class BaseNode:
|
|
|
67
64
|
if not level or not isinstance(level, OverflowLevel):
|
|
68
65
|
return
|
|
69
66
|
self.overflow_level = level
|
|
67
|
+
self.data[GraphConst.OVERFLOW_LEVEL] = self.overflow_level.value
|
|
70
68
|
|
|
71
69
|
def add_upnode(self, node):
|
|
72
70
|
"""
|
|
@@ -104,12 +102,9 @@ class BaseNode:
|
|
|
104
102
|
}
|
|
105
103
|
if self.micro_step_id is not None:
|
|
106
104
|
result['micro_step_id'] = self.micro_step_id
|
|
107
|
-
# 是否存在overflow,并保存结果
|
|
108
|
-
if self.overflow_level and isinstance(self.overflow_level, OverflowLevel):
|
|
109
|
-
if self.data is None:
|
|
110
|
-
self.data = dict()
|
|
111
|
-
self.data['overflow_level'] = self.overflow_level.value
|
|
112
105
|
result['data'] = self.data
|
|
106
|
+
if self.matched_distributed:
|
|
107
|
+
result[GraphConst.MATCHED_DISTRIBUTED] = self.matched_distributed
|
|
113
108
|
return result
|
|
114
109
|
|
|
115
110
|
def get_ancestors(self):
|