mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/RECORD +196 -141
- msprobe/CMakeLists.txt +5 -0
- msprobe/README.md +14 -19
- msprobe/config.json +1 -0
- msprobe/core/common/const.py +155 -6
- msprobe/core/common/exceptions.py +3 -1
- msprobe/core/common/file_utils.py +33 -7
- msprobe/core/common/inplace_ops.yaml +3 -0
- msprobe/core/common/utils.py +28 -14
- msprobe/core/common_config.py +6 -0
- msprobe/core/compare/acc_compare.py +139 -128
- msprobe/core/compare/check.py +31 -29
- msprobe/core/compare/compare_cli.py +17 -16
- msprobe/core/compare/highlight.py +186 -99
- msprobe/core/compare/layer_mapping/data_scope_parser.py +18 -7
- msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
- msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
- msprobe/core/compare/merge_result/merge_result.py +380 -0
- msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
- msprobe/core/compare/multiprocessing_compute.py +2 -2
- msprobe/core/compare/npy_compare.py +109 -147
- msprobe/core/compare/utils.py +189 -69
- msprobe/core/data_dump/data_collector.py +51 -21
- msprobe/core/data_dump/data_processor/base.py +38 -20
- msprobe/core/data_dump/data_processor/factory.py +5 -3
- msprobe/core/data_dump/data_processor/mindspore_processor.py +154 -20
- msprobe/core/data_dump/data_processor/pytorch_processor.py +118 -58
- msprobe/core/data_dump/json_writer.py +29 -1
- msprobe/core/data_dump/scope.py +19 -18
- msprobe/core/overflow_check/abnormal_scene.py +9 -5
- msprobe/core/overflow_check/checker.py +1 -1
- msprobe/core/overflow_check/utils.py +1 -1
- msprobe/docs/01.installation.md +96 -17
- msprobe/docs/02.config_introduction.md +5 -5
- msprobe/docs/05.data_dump_PyTorch.md +91 -61
- msprobe/docs/06.data_dump_MindSpore.md +57 -19
- msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -4
- msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
- msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +120 -27
- msprobe/docs/21.visualization_PyTorch.md +115 -35
- msprobe/docs/22.visualization_MindSpore.md +138 -41
- msprobe/docs/23.generate_operator_PyTorch.md +107 -0
- msprobe/docs/24.code_mapping_Mindspore.md +28 -0
- msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
- msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
- msprobe/docs/27.dump_json_instruction.md +521 -0
- msprobe/docs/FAQ.md +26 -2
- msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
- msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
- msprobe/docs/img/visualization/tensorboard_1.png +0 -0
- msprobe/docs/img/visualization/tensorboard_2.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_browser_2.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/GPTModel.png +0 -0
- msprobe/docs/visualization/ParallelMLP.png +0 -0
- msprobe/docs/visualization/layer_mapping_example.md +132 -0
- msprobe/docs/visualization/mapping.png +0 -0
- msprobe/docs/visualization/mapping1.png +0 -0
- msprobe/docs/visualization/module_name.png +0 -0
- msprobe/docs/visualization/module_name1.png +0 -0
- msprobe/docs/visualization/no_mapping.png +0 -0
- msprobe/docs/visualization/no_mapping1.png +0 -0
- msprobe/docs/visualization/no_mapping_analyze.png +0 -0
- msprobe/docs/visualization/top_layer.png +0 -0
- msprobe/mindspore/__init__.py +10 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +57 -25
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +5 -7
- msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
- msprobe/mindspore/api_accuracy_checker/main.py +1 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
- msprobe/mindspore/code_mapping/bind.py +264 -0
- msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
- msprobe/mindspore/code_mapping/graph.py +49 -0
- msprobe/mindspore/code_mapping/graph_parser.py +226 -0
- msprobe/mindspore/code_mapping/main.py +24 -0
- msprobe/mindspore/code_mapping/processor.py +34 -0
- msprobe/mindspore/common/const.py +3 -1
- msprobe/mindspore/common/utils.py +50 -5
- msprobe/mindspore/compare/distributed_compare.py +0 -2
- msprobe/mindspore/compare/ms_compare.py +105 -63
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/debugger/debugger_config.py +3 -0
- msprobe/mindspore/debugger/precision_debugger.py +81 -12
- msprobe/mindspore/dump/hook_cell/api_registry.py +83 -16
- msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
- msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
- msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
- msprobe/mindspore/grad_probe/hook.py +13 -4
- msprobe/mindspore/mindtorch/__init__.py +18 -0
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
- msprobe/mindspore/ms_config.py +5 -1
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
- msprobe/mindspore/service.py +267 -101
- msprobe/msprobe.py +24 -3
- msprobe/pytorch/__init__.py +7 -6
- msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
- msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
- msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +54 -30
- msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
- msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
- msprobe/pytorch/common/parse_json.py +2 -1
- msprobe/pytorch/common/utils.py +45 -2
- msprobe/pytorch/compare/distributed_compare.py +17 -29
- msprobe/pytorch/compare/pt_compare.py +40 -20
- msprobe/pytorch/debugger/debugger_config.py +27 -12
- msprobe/pytorch/debugger/precision_debugger.py +42 -12
- msprobe/pytorch/dump/module_dump/__init__.py +0 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
- msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +80 -6
- msprobe/pytorch/free_benchmark/common/params.py +2 -1
- msprobe/pytorch/free_benchmark/common/utils.py +3 -0
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
- msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
- msprobe/pytorch/hook_module/__init__.py +1 -1
- msprobe/pytorch/hook_module/hook_module.py +14 -11
- msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +34 -0
- msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
- msprobe/pytorch/hook_module/wrap_functional.py +0 -40
- msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
- msprobe/pytorch/monitor/anomaly_detect.py +107 -22
- msprobe/pytorch/monitor/csv2tb.py +166 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
- msprobe/pytorch/monitor/features.py +3 -3
- msprobe/pytorch/monitor/module_hook.py +483 -277
- msprobe/pytorch/monitor/module_metric.py +27 -48
- msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
- msprobe/pytorch/monitor/optimizer_collect.py +52 -14
- msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
- msprobe/pytorch/monitor/utils.py +77 -6
- msprobe/pytorch/online_dispatch/dispatch.py +8 -2
- msprobe/pytorch/parse_tool/lib/compare.py +10 -10
- msprobe/pytorch/parse_tool/lib/config.py +5 -7
- msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
- msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
- msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
- msprobe/pytorch/parse_tool/lib/utils.py +18 -19
- msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
- msprobe/pytorch/service.py +176 -106
- msprobe/visualization/builder/graph_builder.py +62 -5
- msprobe/visualization/builder/msprobe_adapter.py +24 -2
- msprobe/visualization/compare/graph_comparator.py +64 -14
- msprobe/visualization/compare/mode_adapter.py +1 -15
- msprobe/visualization/graph/base_node.py +12 -17
- msprobe/visualization/graph/distributed_analyzer.py +318 -0
- msprobe/visualization/graph/graph.py +9 -0
- msprobe/visualization/graph_service.py +97 -23
- msprobe/visualization/utils.py +14 -29
- msprobe/pytorch/functional/module_dump.py +0 -84
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
- /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,19 +13,23 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
import math
|
|
17
16
|
import abc
|
|
17
|
+
import math
|
|
18
|
+
import multiprocessing
|
|
18
19
|
import re
|
|
19
20
|
from collections import namedtuple
|
|
21
|
+
|
|
20
22
|
import numpy as np
|
|
21
23
|
import openpyxl
|
|
22
24
|
from openpyxl.styles import PatternFill
|
|
25
|
+
from openpyxl.utils.dataframe import dataframe_to_rows
|
|
23
26
|
from tqdm import tqdm
|
|
24
|
-
|
|
27
|
+
|
|
28
|
+
from msprobe.core.common.const import CompareConst, Const
|
|
25
29
|
from msprobe.core.common.file_utils import save_workbook
|
|
26
30
|
from msprobe.core.common.log import logger
|
|
27
|
-
from msprobe.core.common.
|
|
28
|
-
from msprobe.core.
|
|
31
|
+
from msprobe.core.common.utils import get_header_index, safe_get_value
|
|
32
|
+
from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
class HighlightCheck(abc.ABC):
|
|
@@ -44,6 +48,7 @@ def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
|
44
48
|
|
|
45
49
|
class CheckOrderMagnitude(HighlightCheck):
|
|
46
50
|
"""检查Max diff的数量级差异"""
|
|
51
|
+
|
|
47
52
|
def apply(self, info, color_columns, dump_mode):
|
|
48
53
|
api_in, api_out, num = info
|
|
49
54
|
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
@@ -54,12 +59,13 @@ class CheckOrderMagnitude(HighlightCheck):
|
|
|
54
59
|
out_order = 0 if abs(api_out[max_diff_index]) < 1 else math.log10(abs(api_out[max_diff_index]))
|
|
55
60
|
if out_order - in_order >= CompareConst.ORDER_MAGNITUDE_DIFF_YELLOW:
|
|
56
61
|
add_highlight_row_info(color_columns.yellow, num,
|
|
57
|
-
"maximum absolute error of both input and output exceed 1, "
|
|
62
|
+
"maximum absolute error of both input/parameters and output exceed 1, "
|
|
58
63
|
"with the output larger by an order of magnitude")
|
|
59
64
|
|
|
60
65
|
|
|
61
66
|
class CheckOneThousandErrorRatio(HighlightCheck):
|
|
62
67
|
"""检查千分误差比率"""
|
|
68
|
+
|
|
63
69
|
def apply(self, info, color_columns, dump_mode):
|
|
64
70
|
api_in, api_out, num = info
|
|
65
71
|
one_thousand_index = get_header_index(CompareConst.ONE_THOUSANDTH_ERR_RATIO, dump_mode)
|
|
@@ -69,15 +75,17 @@ class CheckOneThousandErrorRatio(HighlightCheck):
|
|
|
69
75
|
if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
70
76
|
api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
|
|
71
77
|
add_highlight_row_info(color_columns.red, num,
|
|
72
|
-
"The input's one thousandth err ratio exceeds 0.9,
|
|
78
|
+
"The input/parameters's one thousandth err ratio exceeds 0.9, "
|
|
79
|
+
"while the output's is below 0.6")
|
|
73
80
|
elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
|
|
74
81
|
add_highlight_row_info(color_columns.yellow, num,
|
|
75
82
|
"The output's one thousandth err ratio decreases by more than 0.1 "
|
|
76
|
-
"compared to the input's")
|
|
83
|
+
"compared to the input/parameters's")
|
|
77
84
|
|
|
78
85
|
|
|
79
86
|
class CheckCosineSimilarity(HighlightCheck):
|
|
80
87
|
"""检查余弦相似度"""
|
|
88
|
+
|
|
81
89
|
def apply(self, info, color_columns, dump_mode):
|
|
82
90
|
api_in, api_out, num = info
|
|
83
91
|
cosine_index = get_header_index(CompareConst.COSINE, dump_mode)
|
|
@@ -85,17 +93,21 @@ class CheckCosineSimilarity(HighlightCheck):
|
|
|
85
93
|
return
|
|
86
94
|
if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
|
|
87
95
|
add_highlight_row_info(color_columns.yellow, num,
|
|
88
|
-
"The output's cosine decreases by more than 0.1
|
|
96
|
+
"The output's cosine decreases by more than 0.1 "
|
|
97
|
+
"compared to the input/parameters's")
|
|
89
98
|
|
|
90
99
|
|
|
91
100
|
class CheckMaxRelativeDiff(HighlightCheck):
|
|
92
101
|
"""检查最大相对差异"""
|
|
102
|
+
|
|
93
103
|
def apply(self, info, color_columns, dump_mode):
|
|
94
104
|
api_in, api_out, num = info
|
|
95
105
|
max_diff_index = get_header_index(CompareConst.MAX_DIFF, dump_mode)
|
|
96
106
|
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
97
|
-
input_max_relative_diff = np.abs(
|
|
98
|
-
|
|
107
|
+
input_max_relative_diff = np.abs(
|
|
108
|
+
np.divide(api_in[max_diff_index], max(Const.FLOAT_EPSILON, api_in[bench_max_index])))
|
|
109
|
+
output_max_relative_diff = np.abs(
|
|
110
|
+
np.divide(api_out[max_diff_index], max(Const.FLOAT_EPSILON, api_out[bench_max_index])))
|
|
99
111
|
if not isinstance(input_max_relative_diff, (float, int)) or not isinstance(output_max_relative_diff,
|
|
100
112
|
(float, int)):
|
|
101
113
|
return
|
|
@@ -104,11 +116,13 @@ class CheckMaxRelativeDiff(HighlightCheck):
|
|
|
104
116
|
elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
|
|
105
117
|
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
106
118
|
add_highlight_row_info(color_columns.yellow, num,
|
|
107
|
-
"The output's maximum relative error exceeds 0.1,
|
|
119
|
+
"The output's maximum relative error exceeds 0.1, "
|
|
120
|
+
"while the input/parameters's is below 0.01")
|
|
108
121
|
|
|
109
122
|
|
|
110
123
|
class CheckOverflow(HighlightCheck):
|
|
111
124
|
"""检查是否存在溢出"""
|
|
125
|
+
|
|
112
126
|
def apply(self, info, color_columns, dump_mode):
|
|
113
127
|
line, num = info
|
|
114
128
|
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
@@ -141,9 +155,24 @@ class HighlightRules:
|
|
|
141
155
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
142
156
|
"check_max_relative_diff": CheckMaxRelativeDiff(),
|
|
143
157
|
}
|
|
144
|
-
|
|
145
158
|
|
|
146
|
-
|
|
159
|
+
|
|
160
|
+
def check_indices_numeric(api_items, indices: list):
|
|
161
|
+
"""检查指定索引处的值是否都为数字类型(int 或 float)"""
|
|
162
|
+
return all(isinstance(api_items[i], (float, int)) for i in indices)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def apply_comparison_rules(api_info, dump_mode, color_columns):
|
|
166
|
+
"""output与input/params的比较"""
|
|
167
|
+
if dump_mode == Const.SUMMARY:
|
|
168
|
+
for rule in HighlightRules.summary_compare_rules.values():
|
|
169
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
170
|
+
else:
|
|
171
|
+
for rule in HighlightRules.compare_rules.values():
|
|
172
|
+
rule.apply(api_info, color_columns, dump_mode)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def find_error_rows(result, api_batch, highlight_dict, dump_mode):
|
|
147
176
|
"""找到单个API中需要高亮的行"""
|
|
148
177
|
if dump_mode == Const.MD5:
|
|
149
178
|
return
|
|
@@ -158,35 +187,34 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
|
|
|
158
187
|
ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
|
|
159
188
|
color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
|
|
160
189
|
|
|
190
|
+
api_batch_start = api_batch.start # result_df的input起始全局索引
|
|
191
|
+
api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
|
|
192
|
+
api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
|
|
193
|
+
api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
|
|
194
|
+
api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
|
|
195
|
+
|
|
161
196
|
# 对单行API的输入或输出进行误差判断
|
|
162
197
|
for i, line in enumerate(result):
|
|
163
|
-
|
|
164
|
-
line_info = LineInfo(line_data=line, num_pointer=
|
|
198
|
+
index = api_batch_start + i
|
|
199
|
+
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
165
200
|
for rule in HighlightRules.basic_rules.values():
|
|
166
201
|
rule.apply(line_info, color_columns, dump_mode)
|
|
167
202
|
|
|
168
203
|
# 对API的输出与输入比较,进行误差判断
|
|
169
|
-
for n, api_out in enumerate(result[
|
|
170
|
-
|
|
171
|
-
|
|
204
|
+
for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
|
|
205
|
+
index = api_batch_start + api_batch_params_slice_index_local + n
|
|
206
|
+
# 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
|
|
207
|
+
if index in red_lines:
|
|
172
208
|
continue
|
|
173
|
-
if not
|
|
174
|
-
or not isinstance(api_out[bench_max_index], (float, int)) \
|
|
175
|
-
or not isinstance(api_out[max_diff_index], (float, int)):
|
|
209
|
+
if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
|
|
176
210
|
continue
|
|
177
|
-
for _, api_in in enumerate(result[0:n_num_input]):
|
|
178
|
-
if not isinstance(api_in[npu_max_index], (float, int)) \
|
|
179
|
-
or not isinstance(api_in[bench_max_index], (float, int)) \
|
|
180
|
-
or not isinstance(api_in[max_diff_index], (float, int)):
|
|
181
|
-
continue
|
|
182
211
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
rule.apply(api_info, color_columns, dump_mode)
|
|
212
|
+
# input/parameters的比较检查, 这里api_in包括input、parameters
|
|
213
|
+
for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
|
|
214
|
+
if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
|
|
215
|
+
continue
|
|
216
|
+
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
217
|
+
apply_comparison_rules(api_info, dump_mode, color_columns)
|
|
190
218
|
|
|
191
219
|
red_lines_num_set = {x[0] for x in red_lines}
|
|
192
220
|
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
@@ -196,78 +224,148 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, dump_mode):
|
|
|
196
224
|
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
197
225
|
|
|
198
226
|
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
227
|
+
class ApiBatch:
|
|
228
|
+
def __init__(self, api_name: str, start: int):
|
|
229
|
+
self.api_name = api_name
|
|
230
|
+
self.start = start
|
|
231
|
+
self.input_len = 1 # input的数量
|
|
232
|
+
self.params_end_index = start + 1 # params的结束index
|
|
233
|
+
self.output_end_index = start + 1 # output的结束index
|
|
234
|
+
self.params_grad_end_index = start + 1 # params_grad的结束index
|
|
235
|
+
# 内部state的标志("input", "output", "parameters", "parameters_grad"),
|
|
236
|
+
# 用于控制计算input_len, output_end_index, params_end_index, self.params_grad_end_index
|
|
237
|
+
self._state = Const.INPUT # api_batch初始化为input
|
|
238
|
+
|
|
239
|
+
def set_state(self, state: str):
|
|
240
|
+
"""设置当前状态"""
|
|
241
|
+
if state in {Const.INPUT, Const.OUTPUT, Const.KWARGS, Const.PARAMS, Const.PARAMS_GRAD}:
|
|
242
|
+
self._state = state
|
|
243
|
+
else:
|
|
244
|
+
raise ValueError(f"Invalid state: {state}")
|
|
245
|
+
|
|
246
|
+
def increment(self, state: str):
|
|
247
|
+
self.set_state(state)
|
|
248
|
+
if self._state == Const.INPUT or self._state == Const.KWARGS:
|
|
249
|
+
self.input_len += 1
|
|
250
|
+
self.params_end_index += 1
|
|
251
|
+
self.output_end_index += 1
|
|
252
|
+
if self._state == Const.PARAMS:
|
|
253
|
+
self.params_end_index += 1
|
|
254
|
+
self.output_end_index += 1
|
|
255
|
+
if self._state == Const.OUTPUT:
|
|
256
|
+
self.output_end_index += 1
|
|
257
|
+
self.params_grad_end_index += 1
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def api_batches_update(api_batches, api_name, state, index):
|
|
261
|
+
"""
|
|
262
|
+
当一个api的所有item更新完后,input, output的索引范围:
|
|
263
|
+
input: [start: start+input_len]
|
|
264
|
+
output: [start+input_len: output_end_index]
|
|
265
|
+
params: [output_end_index: params_end_index]
|
|
266
|
+
"""
|
|
267
|
+
if not api_batches:
|
|
268
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
204
269
|
else:
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
270
|
+
api_batch = api_batches[-1]
|
|
271
|
+
if api_batch.api_name == api_name or (
|
|
272
|
+
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
273
|
+
try:
|
|
274
|
+
api_batch.increment(state)
|
|
275
|
+
except ValueError as e:
|
|
276
|
+
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
277
|
+
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
278
|
+
else:
|
|
279
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
208
280
|
|
|
209
281
|
|
|
210
282
|
def find_compare_result_error_rows(result_df, highlight_dict, dump_mode):
|
|
211
283
|
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
212
284
|
result = result_df.values
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
num, last_len = 0, 0
|
|
216
|
-
progress_bar = tqdm(total=len(result), desc="API/Module Analyse Progress", unit="item", ncols=100)
|
|
217
|
-
for res_i in result:
|
|
285
|
+
api_batches = []
|
|
286
|
+
for i, res_i in enumerate(result):
|
|
218
287
|
api_full_name = safe_get_value(res_i, 0, "res_i")
|
|
219
288
|
api_name, state = get_name_and_state(api_full_name)
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
input_num, output_num = 1, 0
|
|
289
|
+
api_batches_update(api_batches, api_name, state, i)
|
|
290
|
+
with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
|
|
291
|
+
for api_batch in api_batches:
|
|
292
|
+
find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
|
|
293
|
+
dump_mode)
|
|
294
|
+
progress_bar.update(1)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def value_check(value, api_name=None, i=None, result_df_columns=None):
|
|
298
|
+
if not table_value_is_valid(value):
|
|
299
|
+
if result_df_columns:
|
|
300
|
+
logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
|
|
301
|
+
f"is not allowed to be written into the compare result xlsx.")
|
|
234
302
|
else:
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
303
|
+
logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def df_malicious_value_check(df_chunk, result_df_columns):
|
|
307
|
+
for row in df_chunk.itertuples(index=False):
|
|
308
|
+
api_name = row[0]
|
|
309
|
+
for i, value in enumerate(row):
|
|
310
|
+
value_check(value, api_name, i, result_df_columns)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def handle_multi_process_malicious_value_check(func, result_df):
|
|
314
|
+
result_total_nums = len(result_df)
|
|
315
|
+
process_num = int((multiprocessing.cpu_count() + 1) / 2)
|
|
316
|
+
|
|
317
|
+
if result_total_nums <= process_num:
|
|
318
|
+
process_num = 1
|
|
319
|
+
chunks = [result_df]
|
|
320
|
+
else:
|
|
321
|
+
chunk_size = result_total_nums // process_num
|
|
322
|
+
chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
|
|
323
|
+
|
|
324
|
+
pool = multiprocessing.Pool(process_num)
|
|
325
|
+
|
|
326
|
+
def err_call(args):
|
|
327
|
+
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
328
|
+
try:
|
|
329
|
+
pool.terminate()
|
|
330
|
+
except OSError:
|
|
331
|
+
logger.error("Pool terminate failed")
|
|
332
|
+
|
|
333
|
+
result_df_columns = result_df.columns.tolist()
|
|
334
|
+
for column in result_df_columns:
|
|
335
|
+
value_check(column)
|
|
336
|
+
for df_chunk in chunks:
|
|
337
|
+
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
338
|
+
|
|
339
|
+
pool.close()
|
|
340
|
+
pool.join()
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def compare_result_df_convert(value):
|
|
344
|
+
if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
|
|
345
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
|
|
346
|
+
if isinstance(value, float):
|
|
347
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
|
|
348
|
+
return value
|
|
245
349
|
|
|
246
350
|
|
|
247
351
|
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
248
352
|
"""Write and highlight results in Excel"""
|
|
249
353
|
|
|
250
|
-
update_highlight_err_msg(result_df, highlight_dict)
|
|
354
|
+
update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
251
355
|
|
|
252
356
|
wb = openpyxl.Workbook()
|
|
253
357
|
ws = wb.active
|
|
254
358
|
|
|
255
359
|
# write header
|
|
256
360
|
logger.info('Initializing Excel file.')
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
for
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
|
|
266
|
-
if not csv_value_is_valid(value):
|
|
267
|
-
raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: "
|
|
268
|
-
f"{file_path}.")
|
|
269
|
-
ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
|
|
270
|
-
|
|
361
|
+
|
|
362
|
+
handle_multi_process_malicious_value_check(df_malicious_value_check, result_df)
|
|
363
|
+
|
|
364
|
+
result_df_convert = result_df.applymap(compare_result_df_convert)
|
|
365
|
+
|
|
366
|
+
for row in dataframe_to_rows(result_df_convert, index=False, header=True):
|
|
367
|
+
ws.append(row)
|
|
368
|
+
|
|
271
369
|
# 对可疑数据标色
|
|
272
370
|
logger.info('Coloring Excel in progress.')
|
|
273
371
|
col_len = len(result_df.columns)
|
|
@@ -279,10 +377,11 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
|
279
377
|
)
|
|
280
378
|
for i in highlight_dict.get("red_rows", []):
|
|
281
379
|
for j in range(1, col_len + 1):
|
|
282
|
-
ws.cell(row=i + 2, column=j).fill = red_fill
|
|
380
|
+
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
283
381
|
for i in highlight_dict.get("yellow_rows", []):
|
|
284
382
|
for j in range(1, col_len + 1):
|
|
285
383
|
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
384
|
+
|
|
286
385
|
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
287
386
|
save_workbook(wb, file_path)
|
|
288
387
|
|
|
@@ -314,15 +413,3 @@ def update_highlight_err_msg(result_df, highlight_dict):
|
|
|
314
413
|
red_lines_num_set.add(line_index)
|
|
315
414
|
|
|
316
415
|
result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
def csv_value_is_valid(value: str) -> bool:
|
|
320
|
-
if not isinstance(value, str):
|
|
321
|
-
return True
|
|
322
|
-
try:
|
|
323
|
-
# -1.00 or +1.00 should be consdiered as digit numbers
|
|
324
|
-
float(value)
|
|
325
|
-
except ValueError:
|
|
326
|
-
# otherwise, they will be considered as formular injections
|
|
327
|
-
return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
|
|
328
|
-
return True
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -42,10 +42,10 @@ class DumpDataItem:
|
|
|
42
42
|
construct_scope: str = ""
|
|
43
43
|
scope_direction: Optional[str] = None
|
|
44
44
|
scope_id: Optional[int] = None
|
|
45
|
+
state: str = ""
|
|
45
46
|
|
|
46
47
|
# 类变量使用 ClassVar
|
|
47
|
-
|
|
48
|
-
Const.MS_FRAMEWORK: Const.CELL, Const.PT_FRAMEWORK: Const.MODULE}
|
|
48
|
+
layernames: ClassVar[set] = {Const.CELL, Const.MODULE}
|
|
49
49
|
framework2stack_sign: ClassVar[Dict[str, Tuple[str, str]]] = {
|
|
50
50
|
Const.MS_FRAMEWORK: ("Template", "construct"),
|
|
51
51
|
Const.PT_FRAMEWORK: ("Template", r"in (for|back)ward,")
|
|
@@ -79,19 +79,30 @@ class DumpDataItem:
|
|
|
79
79
|
)
|
|
80
80
|
raise CompareException(CompareException.INVALID_DATA_ERROR)
|
|
81
81
|
|
|
82
|
+
if data_name_list[Const.LAST_INDEX] == Const.PARAMS_GRAD:
|
|
83
|
+
self.api_type = Const.PARAMS_GRAD
|
|
84
|
+
self.api_name = data_name_list[Const.PARAMS_GRAD_NAME_INDEX]
|
|
85
|
+
self.type_name = data_name_list[Const.PARAMS_GRAD_TYPE_NAME_INDEX]
|
|
86
|
+
self.state = Const.PARAMS_GRAD
|
|
87
|
+
return
|
|
88
|
+
|
|
82
89
|
self.api_type = data_name_list[Const.API_TYPE_INDEX]
|
|
83
90
|
self.type_name = data_name_list[Const.TYPE_NAME_INDEX]
|
|
84
|
-
if self.api_type
|
|
91
|
+
if self.api_type in self.layernames:
|
|
85
92
|
self.api_name = data_name_list[Const.LAYER_NAME_INDEX]
|
|
93
|
+
self.state = data_name_list[Const.SCOPE_DIRECTION_INDEX]
|
|
86
94
|
else:
|
|
87
95
|
self.api_name = self.type_name
|
|
96
|
+
self.state = data_name_list[Const.LAST_INDEX]
|
|
88
97
|
|
|
89
98
|
def set_layer_scope(self, construct_info: str) -> None:
|
|
90
99
|
self.construct_scope = construct_info
|
|
91
|
-
if self.api_type
|
|
100
|
+
if self.api_type in self.layernames:
|
|
92
101
|
# remove api name
|
|
93
102
|
data_list = self.data_name.split(Const.SEP)
|
|
94
103
|
data_list = data_list[:Const.LAYER_NAME_INDEX] + data_list[Const.TYPE_NAME_INDEX:]
|
|
104
|
+
elif self.api_type == Const.PARAMS_GRAD:
|
|
105
|
+
data_list = self.data_name.split(Const.SEP)
|
|
95
106
|
elif construct_info:
|
|
96
107
|
data_list = construct_info.split(Const.SEP)
|
|
97
108
|
else:
|
|
@@ -100,7 +111,7 @@ class DumpDataItem:
|
|
|
100
111
|
if data_list:
|
|
101
112
|
self.layer_scope = Const.SEP.join(data_list[:Const.TYPE_NAME_INDEX])
|
|
102
113
|
else:
|
|
103
|
-
self.layer_scope =
|
|
114
|
+
self.layer_scope = Const.TOP_LAYER
|
|
104
115
|
if construct_info:
|
|
105
116
|
construct_list = construct_info.split(Const.SEP)
|
|
106
117
|
if len(construct_list) < abs(Const.LAYER_NAME_INDEX):
|
|
@@ -115,7 +126,7 @@ class DumpDataItem:
|
|
|
115
126
|
|
|
116
127
|
def set_stack_scope(self, stack_info: str) -> None:
|
|
117
128
|
# Cell/Module has no stack info
|
|
118
|
-
if self.api_type
|
|
129
|
+
if self.api_type in self.layernames:
|
|
119
130
|
return
|
|
120
131
|
|
|
121
132
|
if self.api_type in Const.DATA_TYPE_SKIP_LIST or not stack_info:
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
|
+
from collections import defaultdict
|
|
17
18
|
|
|
18
19
|
from msprobe.core.common.const import CompareConst, Const
|
|
19
20
|
from msprobe.core.common.file_utils import load_json, load_yaml, save_yaml
|
|
@@ -21,18 +22,20 @@ from msprobe.core.common.utils import (add_time_with_yaml,
|
|
|
21
22
|
detect_framework_by_dump_json,
|
|
22
23
|
get_stack_construct_by_dump_json_path)
|
|
23
24
|
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
24
|
-
from msprobe.core.compare.utils import read_op
|
|
25
|
+
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
|
+
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
class LayerTrie:
|
|
28
30
|
def __init__(self, type_name, framework=None):
|
|
29
31
|
self.type_name = type_name
|
|
30
|
-
self.data_items =
|
|
32
|
+
self.data_items = defaultdict(list)
|
|
31
33
|
self.children = {}
|
|
32
34
|
self.framework = framework
|
|
33
35
|
|
|
34
36
|
def __repr__(self):
|
|
35
|
-
|
|
37
|
+
data_nums = [{k: len(v)} for k, v in self.data_items.items()]
|
|
38
|
+
return f"Layer(type_name={self.type_name}, data_number={data_nums})"
|
|
36
39
|
|
|
37
40
|
def get(self, name):
|
|
38
41
|
return self.children.get(name)
|
|
@@ -46,10 +49,10 @@ class LayerTrie:
|
|
|
46
49
|
if name not in node.children:
|
|
47
50
|
node.children[name] = LayerTrie(name, data_item.framework)
|
|
48
51
|
node = node.children[name]
|
|
49
|
-
node.data_items.append(data_item)
|
|
52
|
+
node.data_items[data_item.state].append(data_item)
|
|
50
53
|
node.type_name = data_item.type_name
|
|
51
54
|
|
|
52
|
-
def query_data(self, scope, index, default_value=None):
|
|
55
|
+
def query_data(self, scope, state, index, default_value=None):
|
|
53
56
|
parts = scope.split(Const.SEP)
|
|
54
57
|
node = self
|
|
55
58
|
scope_name_list = parts[1:]
|
|
@@ -58,9 +61,9 @@ class LayerTrie:
|
|
|
58
61
|
if name not in node.children:
|
|
59
62
|
return default_value
|
|
60
63
|
node = node.children[name]
|
|
61
|
-
if index >= len(node.data_items):
|
|
64
|
+
if index >= len(node.data_items[state]):
|
|
62
65
|
return default_value
|
|
63
|
-
return node.data_items[index]
|
|
66
|
+
return node.data_items[state][index]
|
|
64
67
|
|
|
65
68
|
def save_to_yaml(self, output_path):
|
|
66
69
|
result = {f"{self.type_name} @ {self}": self.convert_to_dict(self)}
|
|
@@ -70,7 +73,7 @@ class LayerTrie:
|
|
|
70
73
|
|
|
71
74
|
def convert_to_dict(self, node):
|
|
72
75
|
result = {}
|
|
73
|
-
result["data_item"] = [
|
|
76
|
+
result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
|
|
74
77
|
for child_key, child_node in node.children.items():
|
|
75
78
|
key = f"{child_key} @ {child_node}"
|
|
76
79
|
result[key] = self.convert_to_dict(child_node)
|
|
@@ -102,10 +105,11 @@ def convert_scope(layer_trie, data_item, mapping=None):
|
|
|
102
105
|
cur_node = child_node
|
|
103
106
|
idx += 1
|
|
104
107
|
index = -1
|
|
105
|
-
|
|
108
|
+
state = data_item.state
|
|
109
|
+
for idx, child in enumerate(cur_node.data_items[state]):
|
|
106
110
|
if data_item.data_name == child.data_name:
|
|
107
111
|
index = idx
|
|
108
|
-
return new_scope, index
|
|
112
|
+
return new_scope, state, index
|
|
109
113
|
|
|
110
114
|
|
|
111
115
|
def get_data_items_and_tree(dump_json_path, output_path):
|
|
@@ -122,8 +126,8 @@ def get_data_items_and_tree(dump_json_path, output_path):
|
|
|
122
126
|
|
|
123
127
|
|
|
124
128
|
def convert_data_item(npu_tree, bench_tree, npu_data_item, mapping):
|
|
125
|
-
new_scope, index = convert_scope(npu_tree, npu_data_item, mapping)
|
|
126
|
-
bench_data_item = bench_tree.query_data(new_scope, index)
|
|
129
|
+
new_scope, state, index = convert_scope(npu_tree, npu_data_item, mapping)
|
|
130
|
+
bench_data_item = bench_tree.query_data(new_scope, state, index)
|
|
127
131
|
return bench_data_item
|
|
128
132
|
|
|
129
133
|
|
|
@@ -223,7 +227,10 @@ def generate_data_mapping(npu_json_path, bench_json_path, api_mapping, output_pa
|
|
|
223
227
|
continue
|
|
224
228
|
npu_full_op_names = read_full_op_names(npu_data, npu_op_name)
|
|
225
229
|
bench_full_op_names = read_full_op_names(bench_data, bench_op_name)
|
|
226
|
-
|
|
230
|
+
npu_full_op_names_reorder = reorder_op_name_list(npu_full_op_names)
|
|
231
|
+
bench_full_op_names_reorder = reorder_op_name_list(bench_full_op_names)
|
|
232
|
+
mapping = generate_op_data_mapping(npu_op_name, npu_full_op_names_reorder,
|
|
233
|
+
bench_op_name, bench_full_op_names_reorder)
|
|
227
234
|
data_mapping.update(mapping)
|
|
228
235
|
if output_path:
|
|
229
236
|
file_name = add_time_with_yaml("data_mapping")
|
|
@@ -29,9 +29,10 @@ def backward_pass(data_items, name2item):
|
|
|
29
29
|
data_name_list = data_item.data_name.split(Const.SEP)
|
|
30
30
|
if not data_name_list:
|
|
31
31
|
continue
|
|
32
|
-
if Const.BACKWARD in data_name_list[Const.SCOPE_DIRECTION_INDEX
|
|
33
|
-
data_name_list[Const.SCOPE_DIRECTION_INDEX
|
|
34
|
-
s.replace(Const.BACKWARD, Const.FORWARD)
|
|
32
|
+
if Const.BACKWARD in data_name_list[Const.SCOPE_DIRECTION_INDEX:]:
|
|
33
|
+
data_name_list[Const.SCOPE_DIRECTION_INDEX:] = [
|
|
34
|
+
s.replace(Const.BACKWARD, Const.FORWARD)
|
|
35
|
+
for s in data_name_list[Const.SCOPE_DIRECTION_INDEX:]
|
|
35
36
|
]
|
|
36
37
|
forward_name = Const.SEP.join(data_name_list)
|
|
37
38
|
forward_item = name2item.get(forward_name, None)
|