mindstudio-probe 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +168 -150
- msprobe/README.md +27 -22
- msprobe/core/common/const.py +129 -60
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +25 -2
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +43 -33
- msprobe/core/compare/acc_compare.py +43 -74
- msprobe/core/compare/check.py +2 -6
- msprobe/core/compare/highlight.py +2 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
- msprobe/core/compare/merge_result/merge_result.py +16 -9
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/multiprocessing_compute.py +19 -12
- msprobe/core/compare/npy_compare.py +30 -12
- msprobe/core/compare/utils.py +30 -10
- msprobe/core/data_dump/api_registry.py +176 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +94 -10
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +33 -33
- msprobe/core/data_dump/data_processor/pytorch_processor.py +99 -18
- msprobe/core/data_dump/json_writer.py +61 -40
- msprobe/core/grad_probe/constant.py +1 -0
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/docs/01.installation.md +27 -1
- msprobe/docs/02.config_introduction.md +27 -23
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +103 -16
- msprobe/docs/06.data_dump_MindSpore.md +76 -32
- msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
- msprobe/docs/09.accuracy_checker_MindSpore.md +5 -3
- msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
- msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
- msprobe/docs/12.overflow_check_PyTorch.md +3 -1
- msprobe/docs/13.overflow_check_MindSpore.md +4 -2
- msprobe/docs/14.data_parse_PyTorch.md +1 -7
- msprobe/docs/18.online_dispatch.md +1 -1
- msprobe/docs/19.monitor.md +332 -273
- msprobe/docs/21.visualization_PyTorch.md +42 -13
- msprobe/docs/22.visualization_MindSpore.md +43 -13
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/27.dump_json_instruction.md +301 -27
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +4 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +32 -7
- msprobe/mindspore/api_accuracy_checker/api_runner.py +70 -22
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +130 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/const.py +61 -0
- msprobe/mindspore/common/utils.py +48 -18
- msprobe/mindspore/compare/ms_compare.py +27 -19
- msprobe/mindspore/compare/ms_graph_compare.py +6 -5
- msprobe/mindspore/debugger/debugger_config.py +31 -6
- msprobe/mindspore/debugger/precision_debugger.py +45 -14
- msprobe/mindspore/dump/dump_tool_factory.py +5 -3
- msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
- msprobe/mindspore/dump/jit_dump.py +21 -15
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +2 -0
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +873 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +309 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/service.py +114 -34
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +12 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/{parse.py → bench_functions/mish.py} +6 -4
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +50 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +97 -4
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +24 -1
- msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +21 -35
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +8 -2
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/hook_module/api_register.py +131 -0
- msprobe/pytorch/hook_module/hook_module.py +19 -14
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +173 -75
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +18 -14
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +238 -193
- msprobe/pytorch/monitor/module_metric.py +9 -6
- msprobe/pytorch/monitor/optimizer_collect.py +100 -67
- msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
- msprobe/pytorch/monitor/utils.py +76 -44
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +9 -0
- msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +2 -1
- msprobe/pytorch/pt_config.py +30 -29
- msprobe/pytorch/service.py +114 -32
- msprobe/visualization/builder/graph_builder.py +75 -10
- msprobe/visualization/builder/msprobe_adapter.py +7 -6
- msprobe/visualization/compare/graph_comparator.py +42 -38
- msprobe/visualization/compare/mode_adapter.py +0 -19
- msprobe/visualization/graph/base_node.py +11 -3
- msprobe/visualization/graph/distributed_analyzer.py +71 -3
- msprobe/visualization/graph/graph.py +0 -11
- msprobe/visualization/graph/node_op.py +4 -3
- msprobe/visualization/graph_service.py +4 -5
- msprobe/visualization/utils.py +12 -35
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -205
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -75
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -17,12 +17,14 @@ import re
|
|
|
17
17
|
from msprobe.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data
|
|
18
18
|
from msprobe.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df
|
|
19
19
|
from msprobe.visualization.graph.graph import Graph, NodeOp
|
|
20
|
-
from msprobe.visualization.graph.node_colors import NodeColors
|
|
21
20
|
from msprobe.visualization.compare.mode_adapter import ModeAdapter
|
|
22
21
|
from msprobe.core.common.const import Const
|
|
22
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class GraphComparator:
|
|
26
|
+
MAX_DEPTH = 1000
|
|
27
|
+
|
|
26
28
|
def __init__(self, graphs, dump_path_param, args, mapping_dict=None):
|
|
27
29
|
self.graph_n = graphs[0]
|
|
28
30
|
self.graph_b = graphs[1]
|
|
@@ -41,7 +43,7 @@ class GraphComparator:
|
|
|
41
43
|
else:
|
|
42
44
|
self._compare_nodes(self.graph_n.root)
|
|
43
45
|
self._postcompare()
|
|
44
|
-
|
|
46
|
+
|
|
45
47
|
def add_compare_result_to_node(self, node, compare_result_list):
|
|
46
48
|
"""
|
|
47
49
|
将比对结果添加到节点的输入输出数据中
|
|
@@ -66,43 +68,8 @@ class GraphComparator:
|
|
|
66
68
|
self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
|
|
67
69
|
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
68
70
|
node.data.update(other_dict)
|
|
69
|
-
|
|
70
|
-
def _parse_param(self, dump_path_param, output_path):
|
|
71
|
-
self.dump_path_param = dump_path_param
|
|
72
|
-
self.output_path = output_path
|
|
73
|
-
compare_mode = get_compare_mode(self.dump_path_param)
|
|
74
|
-
self.ma = ModeAdapter(compare_mode)
|
|
75
|
-
self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
|
|
76
|
-
self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
|
|
77
|
-
self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
|
|
78
|
-
|
|
79
|
-
def _postcompare(self):
|
|
80
|
-
self._handle_api_collection_index()
|
|
81
|
-
if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
82
|
-
return
|
|
83
|
-
df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
|
|
84
|
-
df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
|
|
85
|
-
compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
|
|
86
|
-
for node in self.ma.compare_nodes:
|
|
87
|
-
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
88
|
-
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
89
|
-
|
|
90
|
-
def _handle_api_collection_index(self):
|
|
91
|
-
"""
|
|
92
|
-
api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
|
|
93
|
-
md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
|
|
94
|
-
"""
|
|
95
|
-
for node in self.graph_n.root.subnodes:
|
|
96
|
-
if node.op == NodeOp.api_collection:
|
|
97
|
-
precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
98
|
-
else GraphConst.MIN_INDEX_KEY
|
|
99
|
-
for api in node.subnodes:
|
|
100
|
-
precision_index = min(precision_index,
|
|
101
|
-
api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
|
|
102
|
-
if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
103
|
-
else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
|
|
104
|
-
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
105
71
|
|
|
72
|
+
@recursion_depth_decorator('GraphComparator._compare_nodes', max_depth=MAX_DEPTH)
|
|
106
73
|
def _compare_nodes(self, node_n):
|
|
107
74
|
"""
|
|
108
75
|
递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比
|
|
@@ -126,6 +93,7 @@ class GraphComparator:
|
|
|
126
93
|
for subnode in node_n.subnodes:
|
|
127
94
|
self._compare_nodes(subnode)
|
|
128
95
|
|
|
96
|
+
@recursion_depth_decorator('GraphComparator._compare_nodes_fuzzy', max_depth=MAX_DEPTH)
|
|
129
97
|
def _compare_nodes_fuzzy(self, node_n):
|
|
130
98
|
if node_n.op != NodeOp.function_api:
|
|
131
99
|
# 模块经过模糊匹配
|
|
@@ -146,6 +114,42 @@ class GraphComparator:
|
|
|
146
114
|
for sub_node in node_n.subnodes:
|
|
147
115
|
self._compare_nodes_fuzzy(sub_node)
|
|
148
116
|
|
|
117
|
+
def _parse_param(self, dump_path_param, output_path):
|
|
118
|
+
self.dump_path_param = dump_path_param
|
|
119
|
+
self.output_path = output_path
|
|
120
|
+
compare_mode = get_compare_mode(self.dump_path_param)
|
|
121
|
+
self.ma = ModeAdapter(compare_mode)
|
|
122
|
+
self.data_n_dict = load_data_json_file(dump_path_param.get('npu_json_path'))
|
|
123
|
+
self.data_b_dict = load_data_json_file(dump_path_param.get('bench_json_path'))
|
|
124
|
+
self.stack_json_data = load_json_file(dump_path_param.get('stack_json_path'))
|
|
125
|
+
|
|
126
|
+
def _postcompare(self):
|
|
127
|
+
self._handle_api_collection_index()
|
|
128
|
+
if not self.ma.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
129
|
+
return
|
|
130
|
+
df = get_csv_df(True, self.ma.csv_data, self.ma.compare_mode)
|
|
131
|
+
df = run_real_data(self.dump_path_param, df, self.framework, True if self.mapping_dict else False)
|
|
132
|
+
compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()}
|
|
133
|
+
for node in self.ma.compare_nodes:
|
|
134
|
+
precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
|
|
135
|
+
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
136
|
+
|
|
137
|
+
def _handle_api_collection_index(self):
|
|
138
|
+
"""
|
|
139
|
+
api集合的指标, md5模式使用集合中所有api最小的指标,statistics和tensor模式使用集合中所有api最大的指标
|
|
140
|
+
md5模式下指标为0代表最差,statistics和tensor模式下指标为1代表最差
|
|
141
|
+
"""
|
|
142
|
+
for node in self.graph_n.root.subnodes:
|
|
143
|
+
if node.op == NodeOp.api_collection:
|
|
144
|
+
precision_index = GraphConst.MAX_INDEX_KEY if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
145
|
+
else GraphConst.MIN_INDEX_KEY
|
|
146
|
+
for api in node.subnodes:
|
|
147
|
+
precision_index = min(precision_index,
|
|
148
|
+
api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MAX_INDEX_KEY)) \
|
|
149
|
+
if self.ma.compare_mode == GraphConst.MD5_COMPARE \
|
|
150
|
+
else max(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, GraphConst.MIN_INDEX_KEY))
|
|
151
|
+
node.data[GraphConst.JSON_INDEX_KEY] = precision_index
|
|
152
|
+
|
|
149
153
|
def _get_and_add_result(self, node_n, node_b):
|
|
150
154
|
compare_result_list = compare_node([node_n.id, node_b.id],
|
|
151
155
|
[self.data_n_dict, self.data_b_dict],
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import json
|
|
17
|
-
import math
|
|
18
17
|
from msprobe.core.common.const import CompareConst, Const
|
|
19
18
|
from msprobe.visualization.utils import ToolTip, GraphConst, str2float
|
|
20
19
|
|
|
@@ -157,24 +156,6 @@ class ModeAdapter:
|
|
|
157
156
|
return
|
|
158
157
|
self.csv_data.extend(compare_result_list)
|
|
159
158
|
|
|
160
|
-
def add_error_key(self, node_data):
|
|
161
|
-
"""
|
|
162
|
-
根据不同的模式进行提供不同错误信息
|
|
163
|
-
"""
|
|
164
|
-
for key, value in node_data.items():
|
|
165
|
-
if not isinstance(value, dict):
|
|
166
|
-
continue
|
|
167
|
-
if self.compare_mode == GraphConst.SUMMARY_COMPARE:
|
|
168
|
-
message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
|
|
169
|
-
CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
|
|
170
|
-
elif self.compare_mode == GraphConst.REAL_DATA_COMPARE:
|
|
171
|
-
message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
|
|
172
|
-
else:
|
|
173
|
-
# 输出件优化
|
|
174
|
-
message = []
|
|
175
|
-
value[GraphConst.ERROR_KEY] = message
|
|
176
|
-
node_data[key] = value
|
|
177
|
-
|
|
178
159
|
def get_tool_tip(self):
|
|
179
160
|
"""
|
|
180
161
|
用于前端展示字段的具体含义
|
|
@@ -12,10 +12,11 @@
|
|
|
12
12
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
|
+
|
|
15
16
|
from msprobe.core.overflow_check.level import OverflowLevel
|
|
16
|
-
from msprobe.visualization.graph.node_op import NodeOp
|
|
17
17
|
from msprobe.visualization.utils import GraphConst
|
|
18
18
|
from msprobe.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_data_fuzzy
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class BaseNode:
|
|
@@ -34,6 +35,7 @@ class BaseNode:
|
|
|
34
35
|
self.micro_step_id = None
|
|
35
36
|
self.overflow_level = None
|
|
36
37
|
self.matched_distributed = {}
|
|
38
|
+
self.batch_p2p_info = []
|
|
37
39
|
|
|
38
40
|
def __str__(self):
|
|
39
41
|
info = f'id:\t{self.id}'
|
|
@@ -92,8 +94,8 @@ class BaseNode:
|
|
|
92
94
|
result = {
|
|
93
95
|
'id': self.id,
|
|
94
96
|
'node_type': self.op.value,
|
|
95
|
-
'output_data': format_node_data(self.output_data),
|
|
96
|
-
'input_data': format_node_data(self.input_data),
|
|
97
|
+
'output_data': format_node_data(self.output_data, self.id),
|
|
98
|
+
'input_data': format_node_data(self.input_data, self.id),
|
|
97
99
|
'upnode': self.upnode.id if self.upnode else 'None',
|
|
98
100
|
'subnodes': [node.id for node in self.subnodes],
|
|
99
101
|
'matched_node_link': self.matched_node_link,
|
|
@@ -113,7 +115,13 @@ class BaseNode:
|
|
|
113
115
|
"""
|
|
114
116
|
ancestors = []
|
|
115
117
|
current_node = self.upnode
|
|
118
|
+
seen_nodes = set()
|
|
116
119
|
while current_node:
|
|
120
|
+
if current_node.id in seen_nodes:
|
|
121
|
+
logger.warning(f'Detected a cycle in the node structure and cannot get node ancestors, '
|
|
122
|
+
f'current node is {current_node.id}.')
|
|
123
|
+
return []
|
|
124
|
+
seen_nodes.add(current_node.id)
|
|
117
125
|
ancestors.append(current_node.id)
|
|
118
126
|
current_node = current_node.upnode
|
|
119
127
|
return list(reversed(ancestors))
|
|
@@ -115,7 +115,9 @@ class DistributedAnalyzer:
|
|
|
115
115
|
if not node_id.startswith(Const.DISTRIBUTED) or node.matched_distributed:
|
|
116
116
|
continue
|
|
117
117
|
api_name, distributed_type = self._get_distributed_name_and_type(node_id)
|
|
118
|
-
if
|
|
118
|
+
if api_name == GraphConst.BATCH_P2P:
|
|
119
|
+
self._batch_p2p_match(node, rank)
|
|
120
|
+
elif distributed_type == DistributedType.P2P:
|
|
119
121
|
self._p2p_match(node, rank, api_name)
|
|
120
122
|
else:
|
|
121
123
|
self._collective_match(node, rank, api_name)
|
|
@@ -138,12 +140,16 @@ class DistributedAnalyzer:
|
|
|
138
140
|
for rank, graph in self.graphs.items():
|
|
139
141
|
group_count = {}
|
|
140
142
|
group_info = {}
|
|
143
|
+
batch_p2p_count = {}
|
|
141
144
|
nodes = graph.node_map
|
|
142
145
|
for node_id, node in nodes.items():
|
|
143
146
|
if not node_id.startswith(Const.DISTRIBUTED):
|
|
144
147
|
continue
|
|
145
148
|
api_name, distributed_type = self._get_distributed_name_and_type(node_id)
|
|
146
|
-
if
|
|
149
|
+
if api_name == GraphConst.BATCH_P2P:
|
|
150
|
+
self._make_batch_p2p_mapping(node, rank, batch_p2p_count)
|
|
151
|
+
continue
|
|
152
|
+
elif distributed_type == DistributedType.P2P:
|
|
147
153
|
config_info = self.config.get(api_name)
|
|
148
154
|
target_rank = self._get_target_rank(node, rank, config_info[1])
|
|
149
155
|
if target_rank is None:
|
|
@@ -162,7 +168,32 @@ class DistributedAnalyzer:
|
|
|
162
168
|
unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(group_count.get(group_id))
|
|
163
169
|
group_info[unique_group_id] = node_id
|
|
164
170
|
group_info[node_id] = unique_group_id
|
|
165
|
-
self.group_node_mapping
|
|
171
|
+
if rank not in self.group_node_mapping:
|
|
172
|
+
self.group_node_mapping[rank] = {}
|
|
173
|
+
self.group_node_mapping[rank].update(group_info)
|
|
174
|
+
|
|
175
|
+
def _make_batch_p2p_mapping(self, node, rank, batch_p2p_count):
|
|
176
|
+
"""
|
|
177
|
+
给batch_isend_irecv接口的每个p2p内容赋予唯一标识
|
|
178
|
+
"""
|
|
179
|
+
if rank not in self.group_node_mapping:
|
|
180
|
+
self.group_node_mapping[rank] = {}
|
|
181
|
+
params = []
|
|
182
|
+
for info_dict in node.batch_p2p_info:
|
|
183
|
+
op = info_dict.get(GraphConst.OP)
|
|
184
|
+
target_rank = info_dict.get(GraphConst.PEER)
|
|
185
|
+
if op is None or target_rank is None:
|
|
186
|
+
logger.warning('Cannot get param op or peer.')
|
|
187
|
+
continue
|
|
188
|
+
group_id = op + Const.REPLACEMENT_CHARACTER + Const.RANK + str(target_rank) + \
|
|
189
|
+
Const.REPLACEMENT_CHARACTER + info_dict.get(GraphConst.GROUP_ID, '')
|
|
190
|
+
batch_p2p_count[group_id] = batch_p2p_count.get(group_id, 0) + 1
|
|
191
|
+
# 例如: isend_rank0_5a4d31ad765260ba50eb190f1f9fd163_1
|
|
192
|
+
unique_group_id = group_id + Const.REPLACEMENT_CHARACTER + str(batch_p2p_count.get(group_id))
|
|
193
|
+
params.append(unique_group_id)
|
|
194
|
+
self.group_node_mapping.get(rank)[unique_group_id] = node.id
|
|
195
|
+
if params:
|
|
196
|
+
self.group_node_mapping.get(rank)[node.id] = params
|
|
166
197
|
|
|
167
198
|
def _get_distributed_name_and_type(self, node_id):
|
|
168
199
|
if Const.SEP not in node_id:
|
|
@@ -316,3 +347,40 @@ class DistributedAnalyzer:
|
|
|
316
347
|
if nodes_info:
|
|
317
348
|
matched_distributed['nodes_info'] = nodes_info
|
|
318
349
|
node.matched_distributed = matched_distributed
|
|
350
|
+
|
|
351
|
+
def _batch_p2p_match(self, node, rank):
|
|
352
|
+
"""
|
|
353
|
+
批量点对点匹配
|
|
354
|
+
|
|
355
|
+
针对torch.distributed.batch_isend_irecv接口,其入参是一个包含点对点通信信息的集合,需要遍历集合对每个点对点通信信息进行匹配
|
|
356
|
+
:param node: 当前集体通信节点
|
|
357
|
+
:param rank: 当前节点所属rank
|
|
358
|
+
:return:
|
|
359
|
+
"""
|
|
360
|
+
unique_group_ids = self.group_node_mapping.get(rank, {}).get(node.id)
|
|
361
|
+
if not unique_group_ids:
|
|
362
|
+
return
|
|
363
|
+
matched_distributed = [] if len(unique_group_ids) > 1 else {}
|
|
364
|
+
for unique_group_id in unique_group_ids:
|
|
365
|
+
try:
|
|
366
|
+
id_info = unique_group_id.split(Const.REPLACEMENT_CHARACTER)
|
|
367
|
+
api_name = id_info[0]
|
|
368
|
+
target_api_name = self.config.get(api_name)[0]
|
|
369
|
+
target_rank = int(id_info[1].replace(Const.RANK, ''))
|
|
370
|
+
except Exception as e:
|
|
371
|
+
logger.warning(f'Failed to parse batch p2p parameter with error info: {e}.')
|
|
372
|
+
continue
|
|
373
|
+
target_node = self._get_target_node(rank, unique_group_id, api_name, target_rank, target_api_name)
|
|
374
|
+
if not target_node:
|
|
375
|
+
continue
|
|
376
|
+
communications_type = self.config.get(api_name)[2]
|
|
377
|
+
index = target_node.data.get(GraphConst.OVERFLOW_LEVEL, CompareConst.NAN) if self.overflow_check \
|
|
378
|
+
else target_node.data.get(GraphConst.JSON_INDEX_KEY, CompareConst.NAN)
|
|
379
|
+
matched_info = {
|
|
380
|
+
'communications_type': communications_type,
|
|
381
|
+
'nodes_info': {target_rank: [str(index), target_node.id]}
|
|
382
|
+
}
|
|
383
|
+
matched_distributed.append(matched_info) if isinstance(matched_distributed, list) \
|
|
384
|
+
else matched_distributed.update(matched_info)
|
|
385
|
+
if matched_distributed:
|
|
386
|
+
node.matched_distributed = matched_distributed
|
|
@@ -20,9 +20,6 @@ from msprobe.core.common.log import logger
|
|
|
20
20
|
from msprobe.core.common.const import Const
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
MAX_RECUR_LEVEL = 100
|
|
24
|
-
|
|
25
|
-
|
|
26
23
|
class Graph:
|
|
27
24
|
def __init__(self, model_name, data_path='', dump_data=None):
|
|
28
25
|
self.node_map = {}
|
|
@@ -67,7 +64,6 @@ class Graph:
|
|
|
67
64
|
ancestors_b = node_b.get_ancestors()
|
|
68
65
|
return node_b, ancestors_n, ancestors_b
|
|
69
66
|
|
|
70
|
-
|
|
71
67
|
@staticmethod
|
|
72
68
|
def fuzzy_match(node_n, node_b):
|
|
73
69
|
if not node_n or not node_b or not node_n.fuzzy_eq(node_b):
|
|
@@ -76,13 +72,6 @@ class Graph:
|
|
|
76
72
|
ancestors_b = node_b.get_ancestors()
|
|
77
73
|
return node_b, ancestors_n, ancestors_b
|
|
78
74
|
|
|
79
|
-
@staticmethod
|
|
80
|
-
def dfs(node, result):
|
|
81
|
-
info = node.to_dict()
|
|
82
|
-
result[node.id] = info
|
|
83
|
-
for subnode in node.subnodes:
|
|
84
|
-
Graph.dfs(subnode, result)
|
|
85
|
-
|
|
86
75
|
@staticmethod
|
|
87
76
|
def split_nodes_by_micro_step(nodes):
|
|
88
77
|
"""
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
from enum import Enum
|
|
17
17
|
import re
|
|
18
18
|
from msprobe.visualization.builder.msprobe_adapter import op_patterns
|
|
19
|
+
from msprobe.core.common.log import logger
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
class NodeOp(Enum):
|
|
@@ -23,7 +24,6 @@ class NodeOp(Enum):
|
|
|
23
24
|
function_api = 1
|
|
24
25
|
api_collection = 9
|
|
25
26
|
|
|
26
|
-
|
|
27
27
|
@staticmethod
|
|
28
28
|
def get_node_op(node_name: str):
|
|
29
29
|
"""
|
|
@@ -32,8 +32,9 @@ class NodeOp(Enum):
|
|
|
32
32
|
for op in NodeOp:
|
|
33
33
|
index = op.value
|
|
34
34
|
if index < 0 or index >= len(op_patterns):
|
|
35
|
-
|
|
35
|
+
continue
|
|
36
36
|
pattern = op_patterns[index]
|
|
37
37
|
if re.match(pattern, node_name):
|
|
38
38
|
return op
|
|
39
|
-
|
|
39
|
+
logger.warning(f"Cannot parse node_name {node_name} into NodeOp, default parsing as module.")
|
|
40
|
+
return NodeOp.module
|
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
import os
|
|
17
17
|
import time
|
|
18
18
|
import json
|
|
19
|
-
from msprobe.core.common.file_utils import (
|
|
20
|
-
check_file_or_directory_path)
|
|
19
|
+
from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
|
|
20
|
+
check_file_or_directory_path, load_json)
|
|
21
21
|
from msprobe.core.common.const import FileCheckConst, Const
|
|
22
22
|
from msprobe.core.common.utils import CompareException
|
|
23
23
|
from msprobe.core.overflow_check.checker import AnomalyDetector
|
|
@@ -159,7 +159,7 @@ def _compare_graph_steps(input_param, args):
|
|
|
159
159
|
bench_steps = sorted(check_and_return_dir_contents(dump_step_b, Const.STEP))
|
|
160
160
|
|
|
161
161
|
if npu_steps != bench_steps:
|
|
162
|
-
logger.error('The number of steps in the two runs
|
|
162
|
+
logger.error('The number of steps in the two runs is different. Unable to match the steps.')
|
|
163
163
|
raise CompareException(CompareException.INVALID_PATH_ERROR)
|
|
164
164
|
|
|
165
165
|
for folder_step in npu_steps:
|
|
@@ -220,8 +220,7 @@ def _graph_service_parser(parser):
|
|
|
220
220
|
|
|
221
221
|
|
|
222
222
|
def _graph_service_command(args):
|
|
223
|
-
|
|
224
|
-
input_param = json.load(file)
|
|
223
|
+
input_param = load_json(args.input_path)
|
|
225
224
|
npu_path = input_param.get("npu_path")
|
|
226
225
|
bench_path = input_param.get("bench_path")
|
|
227
226
|
check_file_or_directory_path(npu_path, isdir=True)
|
msprobe/visualization/utils.py
CHANGED
|
@@ -42,14 +42,6 @@ def load_data_json_file(file_path):
|
|
|
42
42
|
return load_json_file(file_path).get(GraphConst.DATA_KEY, {})
|
|
43
43
|
|
|
44
44
|
|
|
45
|
-
def save_json_file(file_path, data):
|
|
46
|
-
"""
|
|
47
|
-
保存json文件
|
|
48
|
-
"""
|
|
49
|
-
with FileOpen(file_path, 'w') as f:
|
|
50
|
-
f.write(json.dumps(data, indent=4))
|
|
51
|
-
|
|
52
|
-
|
|
53
45
|
def get_csv_df(stack_mode, csv_data, compare_mode):
|
|
54
46
|
"""
|
|
55
47
|
调用acc接口写入csv
|
|
@@ -73,14 +65,6 @@ def str2float(percentage_str):
|
|
|
73
65
|
return 0
|
|
74
66
|
|
|
75
67
|
|
|
76
|
-
def is_integer(s):
|
|
77
|
-
try:
|
|
78
|
-
int(s)
|
|
79
|
-
return True
|
|
80
|
-
except Exception:
|
|
81
|
-
return False
|
|
82
|
-
|
|
83
|
-
|
|
84
68
|
def check_directory_content(input_path):
|
|
85
69
|
"""
|
|
86
70
|
检查input_path内容, 是否全是step{数字}命名的文件夹(例如step0), 或者全是rank{数字}命名的文件夹(例如rank0), 或者全是文件
|
|
@@ -143,18 +127,17 @@ class ToolTip:
|
|
|
143
127
|
'当最大相对误差越接近0表示其计算的误差越小。'
|
|
144
128
|
'当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象'
|
|
145
129
|
)
|
|
146
|
-
SMALL_VALUE_TIP = '{}, 由于{}小于{}, 建议不参考此相对误差,请参考绝对误差'
|
|
147
130
|
|
|
148
131
|
|
|
149
132
|
class GraphConst:
|
|
150
133
|
CONSTRUCT_FILE = 'construct.json'
|
|
151
134
|
DUMP_FILE = 'dump.json'
|
|
152
135
|
STACK_FILE = 'stack.json'
|
|
153
|
-
GRAPH_FILE = 'graph.vis'
|
|
154
136
|
ERROR_KEY = 'error_key'
|
|
155
137
|
SUMMARY_COMPARE = 0
|
|
156
138
|
MD5_COMPARE = 1
|
|
157
139
|
REAL_DATA_COMPARE = 2
|
|
140
|
+
STRUCTURE_COMPARE = 3
|
|
158
141
|
JSON_NPU_KEY = 'NPU'
|
|
159
142
|
JSON_BENCH_KEY = 'Bench'
|
|
160
143
|
JSON_TIP_KEY = 'ToolTip'
|
|
@@ -163,35 +146,22 @@ class GraphConst:
|
|
|
163
146
|
JSON_DATA_KEY = 'dump_data_dir'
|
|
164
147
|
JSON_TASK_KEY = 'task'
|
|
165
148
|
DATA_KEY = 'data'
|
|
166
|
-
REAL_DATA_TH = 0.1
|
|
167
|
-
MAX_RELATIVE_ERR_TH = 0.5
|
|
168
149
|
ROUND_TH = 6
|
|
169
150
|
JSON_INDEX_KEY = 'precision_index'
|
|
170
151
|
MATCHED_DISTRIBUTED = 'matched_distributed'
|
|
171
152
|
OVERFLOW_LEVEL = 'overflow_level'
|
|
172
153
|
MAX_INDEX_KEY = 1
|
|
173
154
|
MIN_INDEX_KEY = 0
|
|
174
|
-
SUGGEST_KEY = 'text'
|
|
175
|
-
TAG_NA = 'na'
|
|
176
|
-
OUTPUT_INDEX_TWO = -2
|
|
177
|
-
OUTPUT_INDEX_THREE = -3
|
|
178
|
-
OUTPUT_MIN_LEN = 3
|
|
179
155
|
INPUT = '.input.'
|
|
180
156
|
OUTPUT = '.output.'
|
|
181
157
|
STR_MAX_LEN = 50
|
|
182
|
-
SMALL_VALUE = 1e-3
|
|
183
158
|
MD5_INDEX_LIST = [CompareConst.RESULT]
|
|
184
|
-
REAL_DATA_INDEX_LIST =
|
|
185
|
-
|
|
186
|
-
SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
|
|
187
|
-
CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
|
|
188
|
-
CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
|
|
189
|
-
VALUE_INDEX_LIST = [Const.MAX, Const.MIN, Const.MEAN, Const.NORM]
|
|
159
|
+
REAL_DATA_INDEX_LIST = CompareConst.ALL_COMPARE_INDEX
|
|
160
|
+
SUMMARY_INDEX_LIST = CompareConst.SUMMARY_COMPARE_INDEX
|
|
190
161
|
APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
|
|
191
162
|
NULL = 'null'
|
|
192
163
|
NONE = 'None'
|
|
193
164
|
VALUE = 'value'
|
|
194
|
-
BRACE = '{}'
|
|
195
165
|
DESCRIPTION = 'description'
|
|
196
166
|
COLORS = 'Colors'
|
|
197
167
|
MICRO_STEPS = 'MicroSteps'
|
|
@@ -200,13 +170,15 @@ class GraphConst:
|
|
|
200
170
|
DUMP_MODE_TO_GRAPHCOMPARE_MODE_MAPPING = {
|
|
201
171
|
Const.ALL: REAL_DATA_COMPARE,
|
|
202
172
|
Const.SUMMARY: SUMMARY_COMPARE,
|
|
203
|
-
Const.MD5: MD5_COMPARE
|
|
173
|
+
Const.MD5: MD5_COMPARE,
|
|
174
|
+
Const.STRUCTURE: STRUCTURE_COMPARE
|
|
204
175
|
}
|
|
205
176
|
|
|
206
177
|
GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING = {
|
|
207
178
|
REAL_DATA_COMPARE: Const.ALL,
|
|
208
179
|
SUMMARY_COMPARE: Const.SUMMARY,
|
|
209
|
-
MD5_COMPARE: Const.MD5
|
|
180
|
+
MD5_COMPARE: Const.MD5,
|
|
181
|
+
STRUCTURE_COMPARE: Const.STRUCTURE
|
|
210
182
|
}
|
|
211
183
|
|
|
212
184
|
RANKS = 'ranks'
|
|
@@ -215,3 +187,8 @@ class GraphConst:
|
|
|
215
187
|
|
|
216
188
|
SRC = 'src'
|
|
217
189
|
DST = 'dst'
|
|
190
|
+
|
|
191
|
+
BATCH_P2P = 'batch_isend_irecv'
|
|
192
|
+
OP = 'op'
|
|
193
|
+
PEER = 'peer'
|
|
194
|
+
GROUP_ID = 'group_id'
|