mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -30,12 +30,7 @@ from msprobe.core.common.file_utils import save_workbook
|
|
|
30
30
|
from msprobe.core.common.log import logger
|
|
31
31
|
from msprobe.core.common.utils import get_header_index, safe_get_value
|
|
32
32
|
from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class HighlightCheck(abc.ABC):
|
|
36
|
-
@abc.abstractmethod
|
|
37
|
-
def apply(self, info, color_columns, dump_mode):
|
|
38
|
-
raise NotImplementedError
|
|
33
|
+
from msprobe.core.compare.config import ModeConfig
|
|
39
34
|
|
|
40
35
|
|
|
41
36
|
def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
@@ -46,6 +41,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
|
46
41
|
color_list.append((num, [highlight_err_msg]))
|
|
47
42
|
|
|
48
43
|
|
|
44
|
+
class HighlightCheck(abc.ABC):
|
|
45
|
+
@abc.abstractmethod
|
|
46
|
+
def apply(self, info, color_columns, dump_mode):
|
|
47
|
+
raise NotImplementedError
|
|
48
|
+
|
|
49
|
+
|
|
49
50
|
class CheckOrderMagnitude(HighlightCheck):
|
|
50
51
|
"""检查Max diff的数量级差异"""
|
|
51
52
|
|
|
@@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck):
|
|
|
75
76
|
if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
76
77
|
api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
|
|
77
78
|
add_highlight_row_info(color_columns.red, num,
|
|
78
|
-
"The input/
|
|
79
|
+
"The input/parameter's one thousandth err ratio exceeds 0.9, "
|
|
79
80
|
"while the output's is below 0.6")
|
|
80
81
|
elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
|
|
81
82
|
add_highlight_row_info(color_columns.yellow, num,
|
|
82
83
|
"The output's one thousandth err ratio decreases by more than 0.1 "
|
|
83
|
-
"compared to the input/
|
|
84
|
+
"compared to the input/parameter's")
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
class CheckCosineSimilarity(HighlightCheck):
|
|
@@ -94,7 +95,7 @@ class CheckCosineSimilarity(HighlightCheck):
|
|
|
94
95
|
if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
|
|
95
96
|
add_highlight_row_info(color_columns.yellow, num,
|
|
96
97
|
"The output's cosine decreases by more than 0.1 "
|
|
97
|
-
"compared to the input/
|
|
98
|
+
"compared to the input/parameter's")
|
|
98
99
|
|
|
99
100
|
|
|
100
101
|
class CheckMaxRelativeDiff(HighlightCheck):
|
|
@@ -117,7 +118,7 @@ class CheckMaxRelativeDiff(HighlightCheck):
|
|
|
117
118
|
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
118
119
|
add_highlight_row_info(color_columns.yellow, num,
|
|
119
120
|
"The output's maximum relative error exceeds 0.1, "
|
|
120
|
-
"while the input/
|
|
121
|
+
"while the input/parameter's is below 0.01")
|
|
121
122
|
|
|
122
123
|
|
|
123
124
|
class CheckOverflow(HighlightCheck):
|
|
@@ -159,73 +160,6 @@ class HighlightRules:
|
|
|
159
160
|
}
|
|
160
161
|
|
|
161
162
|
|
|
162
|
-
def check_indices_numeric(api_items, indices: list):
|
|
163
|
-
"""检查指定索引处的值是否都为数字类型(int 或 float)"""
|
|
164
|
-
return all(isinstance(api_items[i], (float, int)) for i in indices)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def apply_comparison_rules(api_info, dump_mode, color_columns):
|
|
168
|
-
"""output与input/params的比较"""
|
|
169
|
-
if dump_mode == Const.SUMMARY:
|
|
170
|
-
for rule in HighlightRules.summary_compare_rules.values():
|
|
171
|
-
rule.apply(api_info, color_columns, dump_mode)
|
|
172
|
-
else:
|
|
173
|
-
for rule in HighlightRules.compare_rules.values():
|
|
174
|
-
rule.apply(api_info, color_columns, dump_mode)
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
def find_error_rows(result, api_batch, highlight_dict, dump_mode):
|
|
178
|
-
"""找到单个API中需要高亮的行"""
|
|
179
|
-
if dump_mode == Const.MD5:
|
|
180
|
-
return
|
|
181
|
-
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
182
|
-
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
183
|
-
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
184
|
-
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
185
|
-
|
|
186
|
-
red_lines, yellow_lines = [], []
|
|
187
|
-
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
188
|
-
ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
|
|
189
|
-
ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
|
|
190
|
-
color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
|
|
191
|
-
|
|
192
|
-
api_batch_start = api_batch.start # result_df的input起始全局索引
|
|
193
|
-
api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
|
|
194
|
-
api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
|
|
195
|
-
api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
|
|
196
|
-
api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
|
|
197
|
-
|
|
198
|
-
# 对单行API的输入或输出进行误差判断
|
|
199
|
-
for i, line in enumerate(result):
|
|
200
|
-
index = api_batch_start + i
|
|
201
|
-
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
202
|
-
for rule in HighlightRules.basic_rules.values():
|
|
203
|
-
rule.apply(line_info, color_columns, dump_mode)
|
|
204
|
-
|
|
205
|
-
# 对API的输出与输入比较,进行误差判断
|
|
206
|
-
for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
|
|
207
|
-
index = api_batch_start + api_batch_params_slice_index_local + n
|
|
208
|
-
# 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
|
|
209
|
-
if index in red_lines:
|
|
210
|
-
continue
|
|
211
|
-
if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
|
|
212
|
-
continue
|
|
213
|
-
|
|
214
|
-
# input/parameters的比较检查, 这里api_in包括input、parameters
|
|
215
|
-
for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
|
|
216
|
-
if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
|
|
217
|
-
continue
|
|
218
|
-
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
219
|
-
apply_comparison_rules(api_info, dump_mode, color_columns)
|
|
220
|
-
|
|
221
|
-
red_lines_num_set = {x[0] for x in red_lines}
|
|
222
|
-
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
223
|
-
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
224
|
-
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
225
|
-
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
226
|
-
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
227
|
-
|
|
228
|
-
|
|
229
163
|
class ApiBatch:
|
|
230
164
|
def __init__(self, api_name: str, start: int):
|
|
231
165
|
self.api_name = api_name
|
|
@@ -259,159 +193,225 @@ class ApiBatch:
|
|
|
259
193
|
self.params_grad_end_index += 1
|
|
260
194
|
|
|
261
195
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
input: [start: start+input_len]
|
|
266
|
-
output: [start+input_len: output_end_index]
|
|
267
|
-
params: [output_end_index: params_end_index]
|
|
268
|
-
"""
|
|
269
|
-
if not api_batches:
|
|
270
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
271
|
-
else:
|
|
272
|
-
api_batch = api_batches[-1]
|
|
273
|
-
if api_batch.api_name == api_name or (
|
|
274
|
-
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
275
|
-
try:
|
|
276
|
-
api_batch.increment(state)
|
|
277
|
-
except ValueError as e:
|
|
278
|
-
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
279
|
-
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
280
|
-
else:
|
|
281
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
196
|
+
class HighLight:
|
|
197
|
+
def __init__(self, mode_config: ModeConfig):
|
|
198
|
+
self.mode_config = mode_config
|
|
282
199
|
|
|
283
|
-
|
|
284
|
-
def
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
for api_batch in api_batches:
|
|
294
|
-
find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
|
|
295
|
-
dump_mode)
|
|
296
|
-
progress_bar.update(1)
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
def value_check(value, api_name=None, i=None, result_df_columns=None):
|
|
300
|
-
if not table_value_is_valid(value):
|
|
301
|
-
if result_df_columns:
|
|
302
|
-
logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
|
|
303
|
-
f"is not allowed to be written into the compare result xlsx.")
|
|
200
|
+
@staticmethod
|
|
201
|
+
def api_batches_update(api_batches, api_name, state, index):
|
|
202
|
+
"""
|
|
203
|
+
当一个api的所有item更新完后,input, output的索引范围:
|
|
204
|
+
input: [start: start+input_len]
|
|
205
|
+
output: [start+input_len: output_end_index]
|
|
206
|
+
params: [output_end_index: params_end_index]
|
|
207
|
+
"""
|
|
208
|
+
if not api_batches:
|
|
209
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
304
210
|
else:
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
pool = multiprocessing.Pool(process_num)
|
|
327
|
-
|
|
328
|
-
def err_call(args):
|
|
329
|
-
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
330
|
-
try:
|
|
331
|
-
pool.terminate()
|
|
332
|
-
except OSError:
|
|
333
|
-
logger.error("Pool terminate failed")
|
|
334
|
-
|
|
335
|
-
result_df_columns = result_df.columns.tolist()
|
|
336
|
-
for column in result_df_columns:
|
|
337
|
-
value_check(column)
|
|
338
|
-
for df_chunk in chunks:
|
|
339
|
-
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
340
|
-
|
|
341
|
-
pool.close()
|
|
342
|
-
pool.join()
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
def compare_result_df_convert(value):
|
|
346
|
-
if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
|
|
347
|
-
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
|
|
348
|
-
if isinstance(value, float):
|
|
349
|
-
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
|
|
350
|
-
return value
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
354
|
-
"""Write and highlight results in Excel"""
|
|
211
|
+
api_batch = api_batches[-1]
|
|
212
|
+
if api_batch.api_name == api_name or (
|
|
213
|
+
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
214
|
+
try:
|
|
215
|
+
api_batch.increment(state)
|
|
216
|
+
except ValueError as e:
|
|
217
|
+
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
218
|
+
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
219
|
+
else:
|
|
220
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
221
|
+
|
|
222
|
+
@staticmethod
|
|
223
|
+
def check_indices_numeric(api_items, indices: list):
|
|
224
|
+
"""检查指定索引处的值是否都为数字类型(int 或 float)"""
|
|
225
|
+
return all(isinstance(api_items[i], (float, int)) for i in indices)
|
|
226
|
+
|
|
227
|
+
@staticmethod
|
|
228
|
+
def update_highlight_err_msg(result_df, highlight_dict):
|
|
229
|
+
if result_df.shape[1] <= 1:
|
|
230
|
+
return
|
|
355
231
|
|
|
356
|
-
|
|
232
|
+
if CompareConst.NPU_MD5 in result_df.columns:
|
|
233
|
+
return
|
|
357
234
|
|
|
358
|
-
|
|
359
|
-
|
|
235
|
+
err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
|
|
236
|
+
red_lines_num_set = highlight_dict.get('red_rows')
|
|
237
|
+
|
|
238
|
+
for color in ['red', 'yellow']:
|
|
239
|
+
line_key = f'{color}_lines'
|
|
240
|
+
lines = highlight_dict.get(line_key, [])
|
|
241
|
+
for line_index, messages in lines:
|
|
242
|
+
if color == 'yellow' and line_index in red_lines_num_set:
|
|
243
|
+
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
244
|
+
|
|
245
|
+
for msg in messages:
|
|
246
|
+
if err_msg[line_index] == '':
|
|
247
|
+
err_msg[line_index] = msg
|
|
248
|
+
else:
|
|
249
|
+
err_msg[line_index] += '\n' + msg
|
|
250
|
+
|
|
251
|
+
if color == 'red':
|
|
252
|
+
red_lines_num_set.add(line_index)
|
|
253
|
+
|
|
254
|
+
result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
255
|
+
|
|
256
|
+
@staticmethod
|
|
257
|
+
def compare_result_df_convert(value):
|
|
258
|
+
if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
|
|
259
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
|
|
260
|
+
if isinstance(value, float):
|
|
261
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
|
|
262
|
+
return value
|
|
263
|
+
|
|
264
|
+
@staticmethod
|
|
265
|
+
def value_check(value, api_name=None, i=None, result_df_columns=None):
|
|
266
|
+
if not table_value_is_valid(value):
|
|
267
|
+
if result_df_columns:
|
|
268
|
+
logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
|
|
269
|
+
f"is not allowed to be written into the compare result xlsx.")
|
|
270
|
+
else:
|
|
271
|
+
logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
|
|
272
|
+
|
|
273
|
+
def find_compare_result_error_rows(self, result_df, highlight_dict):
|
|
274
|
+
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
275
|
+
result = result_df.values
|
|
276
|
+
api_batches = []
|
|
277
|
+
for i, res_i in enumerate(result):
|
|
278
|
+
api_full_name = safe_get_value(res_i, 0, "res_i")
|
|
279
|
+
api_name, state = get_name_and_state(api_full_name)
|
|
280
|
+
self.api_batches_update(api_batches, api_name, state, i)
|
|
281
|
+
with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
|
|
282
|
+
for api_batch in api_batches:
|
|
283
|
+
self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
|
|
284
|
+
highlight_dict)
|
|
285
|
+
progress_bar.update(1)
|
|
286
|
+
|
|
287
|
+
def find_error_rows(self, result, api_batch, highlight_dict):
|
|
288
|
+
"""找到单个API中需要高亮的行"""
|
|
289
|
+
if self.mode_config.dump_mode == Const.MD5:
|
|
290
|
+
return
|
|
291
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode)
|
|
292
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode)
|
|
293
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY
|
|
294
|
+
else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode)
|
|
295
|
+
|
|
296
|
+
red_lines, yellow_lines = [], []
|
|
297
|
+
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
298
|
+
ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
|
|
299
|
+
ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
|
|
300
|
+
color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
|
|
301
|
+
|
|
302
|
+
api_batch_start = api_batch.start # result_df的input起始全局索引
|
|
303
|
+
api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
|
|
304
|
+
api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
|
|
305
|
+
api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
|
|
306
|
+
api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
|
|
307
|
+
|
|
308
|
+
# 对单行API的输入或输出进行误差判断
|
|
309
|
+
for i, line in enumerate(result):
|
|
310
|
+
index = api_batch_start + i
|
|
311
|
+
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
312
|
+
for rule in HighlightRules.basic_rules.values():
|
|
313
|
+
rule.apply(line_info, color_columns, self.mode_config.dump_mode)
|
|
314
|
+
|
|
315
|
+
# 对API的输出与输入比较,进行误差判断
|
|
316
|
+
for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
|
|
317
|
+
index = api_batch_start + api_batch_params_slice_index_local + n
|
|
318
|
+
# 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
|
|
319
|
+
if index in red_lines:
|
|
320
|
+
continue
|
|
321
|
+
if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
|
|
322
|
+
continue
|
|
360
323
|
|
|
361
|
-
|
|
362
|
-
|
|
324
|
+
# input/parameters的比较检查, 这里api_in包括input、parameters
|
|
325
|
+
for api_in in result[0: api_batch_params_slice_index_local]:
|
|
326
|
+
if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
|
|
327
|
+
continue
|
|
328
|
+
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
329
|
+
self.apply_comparison_rules(api_info, color_columns)
|
|
330
|
+
|
|
331
|
+
red_lines_num_set = {x[0] for x in red_lines}
|
|
332
|
+
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
333
|
+
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
334
|
+
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
335
|
+
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
336
|
+
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
337
|
+
|
|
338
|
+
def apply_comparison_rules(self, api_info, color_columns):
|
|
339
|
+
"""output与input/params的比较"""
|
|
340
|
+
if self.mode_config.dump_mode == Const.SUMMARY:
|
|
341
|
+
for rule in HighlightRules.summary_compare_rules.values():
|
|
342
|
+
rule.apply(api_info, color_columns, self.mode_config.dump_mode)
|
|
343
|
+
else:
|
|
344
|
+
for rule in HighlightRules.compare_rules.values():
|
|
345
|
+
rule.apply(api_info, color_columns, self.mode_config.dump_mode)
|
|
363
346
|
|
|
364
|
-
|
|
347
|
+
def highlight_rows_xlsx(self, result_df, highlight_dict, file_path):
|
|
348
|
+
"""Write and highlight results in Excel"""
|
|
365
349
|
|
|
366
|
-
|
|
350
|
+
self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
367
351
|
|
|
368
|
-
|
|
369
|
-
ws.
|
|
352
|
+
wb = openpyxl.Workbook()
|
|
353
|
+
ws = wb.active
|
|
370
354
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
col_len = len(result_df.columns)
|
|
374
|
-
red_fill = PatternFill(
|
|
375
|
-
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
376
|
-
)
|
|
377
|
-
yellow_fill = PatternFill(
|
|
378
|
-
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
379
|
-
)
|
|
380
|
-
for i in highlight_dict.get("red_rows", []):
|
|
381
|
-
for j in range(1, col_len + 1):
|
|
382
|
-
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
383
|
-
for i in highlight_dict.get("yellow_rows", []):
|
|
384
|
-
for j in range(1, col_len + 1):
|
|
385
|
-
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
355
|
+
# write header
|
|
356
|
+
logger.info('Initializing Excel file.')
|
|
386
357
|
|
|
387
|
-
|
|
388
|
-
save_workbook(wb, file_path)
|
|
358
|
+
self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
|
|
389
359
|
|
|
360
|
+
result_df_convert = result_df.applymap(self.compare_result_df_convert)
|
|
390
361
|
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
return
|
|
362
|
+
for row in dataframe_to_rows(result_df_convert, index=False, header=True):
|
|
363
|
+
ws.append(row)
|
|
394
364
|
|
|
395
|
-
|
|
396
|
-
|
|
365
|
+
# 对可疑数据标色
|
|
366
|
+
logger.info('Coloring Excel in progress.')
|
|
367
|
+
col_len = len(result_df.columns)
|
|
368
|
+
red_fill = PatternFill(
|
|
369
|
+
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
370
|
+
)
|
|
371
|
+
yellow_fill = PatternFill(
|
|
372
|
+
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
373
|
+
)
|
|
374
|
+
for i in highlight_dict.get("red_rows", []):
|
|
375
|
+
for j in range(1, col_len + 1):
|
|
376
|
+
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
377
|
+
for i in highlight_dict.get("yellow_rows", []):
|
|
378
|
+
for j in range(1, col_len + 1):
|
|
379
|
+
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
397
380
|
|
|
398
|
-
|
|
399
|
-
|
|
381
|
+
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
382
|
+
save_workbook(wb, file_path)
|
|
400
383
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
for line_index, messages in lines:
|
|
405
|
-
if color == 'yellow' and line_index in red_lines_num_set:
|
|
406
|
-
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
384
|
+
def handle_multi_process_malicious_value_check(self, func, result_df):
    """Run ``func`` over row-wise slices of ``result_df`` in a process pool.

    The column headers are checked in the calling process through
    ``self.value_check``. The rows are then cut into contiguous slices and
    every slice is submitted as ``func(df_chunk, result_df_columns)``.
    Return values of ``func`` are discarded; a failure inside a worker is
    logged by the error callback and is not re-raised here.

    Args:
        func: Callable taking ``(df_chunk, result_df_columns)``. It is sent
            to worker processes, so it must be picklable.
        result_df: DataFrame whose headers and rows are to be checked.
    """
    row_count = len(result_df)
    # Use roughly half of the available cores, rounded up.
    process_num = (multiprocessing.cpu_count() + 1) // 2

    if row_count <= process_num:
        # Not enough rows to be worth splitting: one worker, one chunk.
        process_num = 1
        chunks = [result_df]
    else:
        chunk_size = row_count // process_num
        # Integer division may leave a remainder, so there can be more
        # chunks than workers; the pool simply queues the extra ones.
        chunks = [result_df.iloc[start: start + chunk_size] for start in range(0, row_count, chunk_size)]

    result_df_columns = result_df.columns.tolist()
    # Check the headers before any worker exists, so that a failure here
    # cannot leave a running pool behind.
    for column in result_df_columns:
        self.value_check(column)

    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
        try:
            pool.close()
        except OSError:
            logger.error("Pool terminate failed")

    try:
        for df_chunk in chunks:
            pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
    finally:
        # Always stop accepting work and wait for the workers, even when
        # submitting a task raised.
        pool.close()
        pool.join()
|
|
412
|
+
|
|
413
|
+
def df_malicious_value_check(self, df_chunk, result_df_columns):
    """Pass every cell of ``df_chunk`` to ``self.value_check``.

    Rows are visited in order. For each row the first cell is forwarded as
    the ``api_name`` argument, together with the cell's positional index
    and the full list of column headers.

    Args:
        df_chunk: DataFrame slice whose cells are checked.
        result_df_columns: List of column headers, forwarded unchanged to
            ``self.value_check``.
    """
    records = df_chunk.itertuples(index=False)
    for record in records:
        # The first cell of the row is what value_check receives as api_name.
        row_api_name = record[0]
        for position, cell in enumerate(record):
            self.value_check(cell, row_api_name, position, result_df_columns)
|
|
@@ -164,6 +164,8 @@ def preprocess_layer_mapping(mapping):
|
|
|
164
164
|
for key, value in name_map.items():
|
|
165
165
|
key_list = key.split('.')
|
|
166
166
|
prefix = key_list[0] # 取前缀
|
|
167
|
+
value_list = value.split('(')
|
|
168
|
+
value = value_list[0] # 取前缀
|
|
167
169
|
key_len = len(key_list)
|
|
168
170
|
if prefix not in final_mapping[type_name]:
|
|
169
171
|
final_mapping[type_name][prefix] = []
|
|
@@ -33,8 +33,8 @@ def check_compare_result_name(file_name):
|
|
|
33
33
|
"""
|
|
34
34
|
check whether the compare result name is as expected
|
|
35
35
|
"""
|
|
36
|
-
single_rank_pattern = r"^
|
|
37
|
-
multi_ranks_pattern = r"^compare_result_rank(\d+)
|
|
36
|
+
single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$"
|
|
37
|
+
multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$"
|
|
38
38
|
if re.match(multi_ranks_pattern, file_name):
|
|
39
39
|
return True
|
|
40
40
|
if re.match(single_rank_pattern, file_name):
|
|
@@ -48,7 +48,7 @@ def reorder_path(compare_result_path_list):
|
|
|
48
48
|
"""
|
|
49
49
|
reorder compare results by rank num
|
|
50
50
|
"""
|
|
51
|
-
rank_pattern = r"compare_result_rank(\d+)
|
|
51
|
+
rank_pattern = r"compare_result_rank(\d+)"
|
|
52
52
|
reorder_path_list = sorted(
|
|
53
53
|
compare_result_path_list,
|
|
54
54
|
key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1))
|
|
@@ -238,7 +238,7 @@ def handle_multi_process(func, func_args, lock):
|
|
|
238
238
|
def err_call(args):
|
|
239
239
|
logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
|
|
240
240
|
try:
|
|
241
|
-
pool.
|
|
241
|
+
pool.close()
|
|
242
242
|
except OSError:
|
|
243
243
|
logger.error("Pool terminate failed")
|
|
244
244
|
|