mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
|
@@ -30,12 +30,7 @@ from msprobe.core.common.file_utils import save_workbook
|
|
|
30
30
|
from msprobe.core.common.log import logger
|
|
31
31
|
from msprobe.core.common.utils import get_header_index, safe_get_value
|
|
32
32
|
from msprobe.core.compare.utils import table_value_is_valid, get_name_and_state, CompareException
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class HighlightCheck(abc.ABC):
|
|
36
|
-
@abc.abstractmethod
|
|
37
|
-
def apply(self, info, color_columns, dump_mode):
|
|
38
|
-
raise NotImplementedError
|
|
33
|
+
from msprobe.core.compare.config import ModeConfig
|
|
39
34
|
|
|
40
35
|
|
|
41
36
|
def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
@@ -46,6 +41,12 @@ def add_highlight_row_info(color_list, num, highlight_err_msg):
|
|
|
46
41
|
color_list.append((num, [highlight_err_msg]))
|
|
47
42
|
|
|
48
43
|
|
|
44
|
+
class HighlightCheck(abc.ABC):
|
|
45
|
+
@abc.abstractmethod
|
|
46
|
+
def apply(self, info, color_columns, dump_mode):
|
|
47
|
+
raise NotImplementedError
|
|
48
|
+
|
|
49
|
+
|
|
49
50
|
class CheckOrderMagnitude(HighlightCheck):
|
|
50
51
|
"""检查Max diff的数量级差异"""
|
|
51
52
|
|
|
@@ -75,12 +76,12 @@ class CheckOneThousandErrorRatio(HighlightCheck):
|
|
|
75
76
|
if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
|
|
76
77
|
api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
|
|
77
78
|
add_highlight_row_info(color_columns.red, num,
|
|
78
|
-
"The input/
|
|
79
|
+
"The input/parameter's one thousandth err ratio exceeds 0.9, "
|
|
79
80
|
"while the output's is below 0.6")
|
|
80
81
|
elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
|
|
81
82
|
add_highlight_row_info(color_columns.yellow, num,
|
|
82
83
|
"The output's one thousandth err ratio decreases by more than 0.1 "
|
|
83
|
-
"compared to the input/
|
|
84
|
+
"compared to the input/parameter's")
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
class CheckCosineSimilarity(HighlightCheck):
|
|
@@ -94,7 +95,7 @@ class CheckCosineSimilarity(HighlightCheck):
|
|
|
94
95
|
if api_in[cosine_index] - api_out[cosine_index] > CompareConst.COSINE_DIFF_YELLOW:
|
|
95
96
|
add_highlight_row_info(color_columns.yellow, num,
|
|
96
97
|
"The output's cosine decreases by more than 0.1 "
|
|
97
|
-
"compared to the input/
|
|
98
|
+
"compared to the input/parameter's")
|
|
98
99
|
|
|
99
100
|
|
|
100
101
|
class CheckMaxRelativeDiff(HighlightCheck):
|
|
@@ -117,7 +118,7 @@ class CheckMaxRelativeDiff(HighlightCheck):
|
|
|
117
118
|
input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
|
|
118
119
|
add_highlight_row_info(color_columns.yellow, num,
|
|
119
120
|
"The output's maximum relative error exceeds 0.1, "
|
|
120
|
-
"while the input/
|
|
121
|
+
"while the input/parameter's is below 0.01")
|
|
121
122
|
|
|
122
123
|
|
|
123
124
|
class CheckOverflow(HighlightCheck):
|
|
@@ -146,84 +147,19 @@ class HighlightRules:
|
|
|
146
147
|
}
|
|
147
148
|
|
|
148
149
|
# 用于比较输入和输出的规则
|
|
150
|
+
# 真实数据检查规则
|
|
149
151
|
compare_rules = {
|
|
150
152
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
151
153
|
"check_one_thousand_error": CheckOneThousandErrorRatio(),
|
|
152
154
|
"check_cosine_similarity": CheckCosineSimilarity()
|
|
153
155
|
}
|
|
156
|
+
# 统计量数据检查规则
|
|
154
157
|
summary_compare_rules = {
|
|
155
158
|
"check_order_magnitude": CheckOrderMagnitude(),
|
|
156
159
|
"check_max_relative_diff": CheckMaxRelativeDiff(),
|
|
157
160
|
}
|
|
158
161
|
|
|
159
162
|
|
|
160
|
-
def check_indices_numeric(api_items, indices: list):
|
|
161
|
-
"""检查指定索引处的值是否都为数字类型(int 或 float)"""
|
|
162
|
-
return all(isinstance(api_items[i], (float, int)) for i in indices)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def apply_comparison_rules(api_info, dump_mode, color_columns):
|
|
166
|
-
"""output与input/params的比较"""
|
|
167
|
-
if dump_mode == Const.SUMMARY:
|
|
168
|
-
for rule in HighlightRules.summary_compare_rules.values():
|
|
169
|
-
rule.apply(api_info, color_columns, dump_mode)
|
|
170
|
-
else:
|
|
171
|
-
for rule in HighlightRules.compare_rules.values():
|
|
172
|
-
rule.apply(api_info, color_columns, dump_mode)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def find_error_rows(result, api_batch, highlight_dict, dump_mode):
|
|
176
|
-
"""找到单个API中需要高亮的行"""
|
|
177
|
-
if dump_mode == Const.MD5:
|
|
178
|
-
return
|
|
179
|
-
npu_max_index = get_header_index(CompareConst.NPU_MAX, dump_mode)
|
|
180
|
-
bench_max_index = get_header_index(CompareConst.BENCH_MAX, dump_mode)
|
|
181
|
-
max_diff_index = get_header_index(CompareConst.MAX_DIFF if dump_mode == Const.SUMMARY
|
|
182
|
-
else CompareConst.MAX_ABS_ERR, dump_mode)
|
|
183
|
-
|
|
184
|
-
red_lines, yellow_lines = [], []
|
|
185
|
-
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
186
|
-
ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
|
|
187
|
-
ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
|
|
188
|
-
color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
|
|
189
|
-
|
|
190
|
-
api_batch_start = api_batch.start # result_df的input起始全局索引
|
|
191
|
-
api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
|
|
192
|
-
api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
|
|
193
|
-
api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
|
|
194
|
-
api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
|
|
195
|
-
|
|
196
|
-
# 对单行API的输入或输出进行误差判断
|
|
197
|
-
for i, line in enumerate(result):
|
|
198
|
-
index = api_batch_start + i
|
|
199
|
-
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
200
|
-
for rule in HighlightRules.basic_rules.values():
|
|
201
|
-
rule.apply(line_info, color_columns, dump_mode)
|
|
202
|
-
|
|
203
|
-
# 对API的输出与输入比较,进行误差判断
|
|
204
|
-
for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
|
|
205
|
-
index = api_batch_start + api_batch_params_slice_index_local + n
|
|
206
|
-
# 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
|
|
207
|
-
if index in red_lines:
|
|
208
|
-
continue
|
|
209
|
-
if not check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
|
|
210
|
-
continue
|
|
211
|
-
|
|
212
|
-
# input/parameters的比较检查, 这里api_in包括input、parameters
|
|
213
|
-
for _, api_in in enumerate(result[0: api_batch_params_slice_index_local]):
|
|
214
|
-
if not check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
|
|
215
|
-
continue
|
|
216
|
-
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
217
|
-
apply_comparison_rules(api_info, dump_mode, color_columns)
|
|
218
|
-
|
|
219
|
-
red_lines_num_set = {x[0] for x in red_lines}
|
|
220
|
-
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
221
|
-
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
222
|
-
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
223
|
-
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
224
|
-
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
225
|
-
|
|
226
|
-
|
|
227
163
|
class ApiBatch:
|
|
228
164
|
def __init__(self, api_name: str, start: int):
|
|
229
165
|
self.api_name = api_name
|
|
@@ -257,159 +193,225 @@ class ApiBatch:
|
|
|
257
193
|
self.params_grad_end_index += 1
|
|
258
194
|
|
|
259
195
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
input: [start: start+input_len]
|
|
264
|
-
output: [start+input_len: output_end_index]
|
|
265
|
-
params: [output_end_index: params_end_index]
|
|
266
|
-
"""
|
|
267
|
-
if not api_batches:
|
|
268
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
269
|
-
else:
|
|
270
|
-
api_batch = api_batches[-1]
|
|
271
|
-
if api_batch.api_name == api_name or (
|
|
272
|
-
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
273
|
-
try:
|
|
274
|
-
api_batch.increment(state)
|
|
275
|
-
except ValueError as e:
|
|
276
|
-
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
277
|
-
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
278
|
-
else:
|
|
279
|
-
api_batches.append(ApiBatch(api_name, index))
|
|
196
|
+
class HighLight:
|
|
197
|
+
def __init__(self, mode_config: ModeConfig):
|
|
198
|
+
self.mode_config = mode_config
|
|
280
199
|
|
|
281
|
-
|
|
282
|
-
def
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
for api_batch in api_batches:
|
|
292
|
-
find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch, highlight_dict,
|
|
293
|
-
dump_mode)
|
|
294
|
-
progress_bar.update(1)
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
def value_check(value, api_name=None, i=None, result_df_columns=None):
|
|
298
|
-
if not table_value_is_valid(value):
|
|
299
|
-
if result_df_columns:
|
|
300
|
-
logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
|
|
301
|
-
f"is not allowed to be written into the compare result xlsx.")
|
|
200
|
+
@staticmethod
|
|
201
|
+
def api_batches_update(api_batches, api_name, state, index):
|
|
202
|
+
"""
|
|
203
|
+
当一个api的所有item更新完后,input, output的索引范围:
|
|
204
|
+
input: [start: start+input_len]
|
|
205
|
+
output: [start+input_len: output_end_index]
|
|
206
|
+
params: [output_end_index: params_end_index]
|
|
207
|
+
"""
|
|
208
|
+
if not api_batches:
|
|
209
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
302
210
|
else:
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
pool = multiprocessing.Pool(process_num)
|
|
325
|
-
|
|
326
|
-
def err_call(args):
|
|
327
|
-
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
328
|
-
try:
|
|
329
|
-
pool.terminate()
|
|
330
|
-
except OSError:
|
|
331
|
-
logger.error("Pool terminate failed")
|
|
332
|
-
|
|
333
|
-
result_df_columns = result_df.columns.tolist()
|
|
334
|
-
for column in result_df_columns:
|
|
335
|
-
value_check(column)
|
|
336
|
-
for df_chunk in chunks:
|
|
337
|
-
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
338
|
-
|
|
339
|
-
pool.close()
|
|
340
|
-
pool.join()
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
def compare_result_df_convert(value):
|
|
344
|
-
if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
|
|
345
|
-
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
|
|
346
|
-
if isinstance(value, float):
|
|
347
|
-
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
|
|
348
|
-
return value
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
def highlight_rows_xlsx(result_df, highlight_dict, file_path):
|
|
352
|
-
"""Write and highlight results in Excel"""
|
|
211
|
+
api_batch = api_batches[-1]
|
|
212
|
+
if api_batch.api_name == api_name or (
|
|
213
|
+
not re.search(Const.REGEX_FORWARD_BACKWARD, api_name) and api_name in api_batch.api_name):
|
|
214
|
+
try:
|
|
215
|
+
api_batch.increment(state)
|
|
216
|
+
except ValueError as e:
|
|
217
|
+
logger.error(f"api_batch: {api_batch} with invalid state, please check! {e}")
|
|
218
|
+
raise CompareException(CompareException.INVALID_STATE_ERROR) from e
|
|
219
|
+
else:
|
|
220
|
+
api_batches.append(ApiBatch(api_name, index))
|
|
221
|
+
|
|
222
|
+
@staticmethod
|
|
223
|
+
def check_indices_numeric(api_items, indices: list):
|
|
224
|
+
"""检查指定索引处的值是否都为数字类型(int 或 float)"""
|
|
225
|
+
return all(isinstance(api_items[i], (float, int)) for i in indices)
|
|
226
|
+
|
|
227
|
+
@staticmethod
|
|
228
|
+
def update_highlight_err_msg(result_df, highlight_dict):
|
|
229
|
+
if result_df.shape[1] <= 1:
|
|
230
|
+
return
|
|
353
231
|
|
|
354
|
-
|
|
232
|
+
if CompareConst.NPU_MD5 in result_df.columns:
|
|
233
|
+
return
|
|
355
234
|
|
|
356
|
-
|
|
357
|
-
|
|
235
|
+
err_msg = result_df.get(CompareConst.ERROR_MESSAGE)
|
|
236
|
+
red_lines_num_set = highlight_dict.get('red_rows')
|
|
237
|
+
|
|
238
|
+
for color in ['red', 'yellow']:
|
|
239
|
+
line_key = f'{color}_lines'
|
|
240
|
+
lines = highlight_dict.get(line_key, [])
|
|
241
|
+
for line_index, messages in lines:
|
|
242
|
+
if color == 'yellow' and line_index in red_lines_num_set:
|
|
243
|
+
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
244
|
+
|
|
245
|
+
for msg in messages:
|
|
246
|
+
if err_msg[line_index] == '':
|
|
247
|
+
err_msg[line_index] = msg
|
|
248
|
+
else:
|
|
249
|
+
err_msg[line_index] += '\n' + msg
|
|
250
|
+
|
|
251
|
+
if color == 'red':
|
|
252
|
+
red_lines_num_set.add(line_index)
|
|
253
|
+
|
|
254
|
+
result_df[CompareConst.ERROR_MESSAGE] = err_msg
|
|
255
|
+
|
|
256
|
+
@staticmethod
|
|
257
|
+
def compare_result_df_convert(value):
|
|
258
|
+
if not isinstance(value, (float, int)) or isinstance(value, bool): # bool类型或者非数字类型转str
|
|
259
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else str(value)
|
|
260
|
+
if isinstance(value, float):
|
|
261
|
+
value = f"{str(value)}\t" if str(value) in ("inf", "-inf", "nan") else value
|
|
262
|
+
return value
|
|
263
|
+
|
|
264
|
+
@staticmethod
|
|
265
|
+
def value_check(value, api_name=None, i=None, result_df_columns=None):
|
|
266
|
+
if not table_value_is_valid(value):
|
|
267
|
+
if result_df_columns:
|
|
268
|
+
logger.error(f"Malicious value [{value}] at api_name [{api_name}], column [{result_df_columns[i]}], "
|
|
269
|
+
f"is not allowed to be written into the compare result xlsx.")
|
|
270
|
+
else:
|
|
271
|
+
logger.error(f"Malicious value [{value}] is not allowed to be written into the compare result xlsx.")
|
|
272
|
+
|
|
273
|
+
def find_compare_result_error_rows(self, result_df, highlight_dict):
|
|
274
|
+
"""将dataframe根据API分组,并找到有误差的算子用于高亮"""
|
|
275
|
+
result = result_df.values
|
|
276
|
+
api_batches = []
|
|
277
|
+
for i, res_i in enumerate(result):
|
|
278
|
+
api_full_name = safe_get_value(res_i, 0, "res_i")
|
|
279
|
+
api_name, state = get_name_and_state(api_full_name)
|
|
280
|
+
self.api_batches_update(api_batches, api_name, state, i)
|
|
281
|
+
with tqdm(total=len(api_batches), desc="API/Module Analyse Progress", unit="item", ncols=100) as progress_bar:
|
|
282
|
+
for api_batch in api_batches:
|
|
283
|
+
self.find_error_rows(result[api_batch.start: api_batch.params_grad_end_index], api_batch,
|
|
284
|
+
highlight_dict)
|
|
285
|
+
progress_bar.update(1)
|
|
286
|
+
|
|
287
|
+
def find_error_rows(self, result, api_batch, highlight_dict):
|
|
288
|
+
"""找到单个API中需要高亮的行"""
|
|
289
|
+
if self.mode_config.dump_mode == Const.MD5:
|
|
290
|
+
return
|
|
291
|
+
npu_max_index = get_header_index(CompareConst.NPU_MAX, self.mode_config.dump_mode)
|
|
292
|
+
bench_max_index = get_header_index(CompareConst.BENCH_MAX, self.mode_config.dump_mode)
|
|
293
|
+
max_diff_index = get_header_index(CompareConst.MAX_DIFF if self.mode_config.dump_mode == Const.SUMMARY
|
|
294
|
+
else CompareConst.MAX_ABS_ERR, self.mode_config.dump_mode)
|
|
295
|
+
|
|
296
|
+
red_lines, yellow_lines = [], []
|
|
297
|
+
LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer'])
|
|
298
|
+
ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer'])
|
|
299
|
+
ColorColumns = namedtuple('ColorColumns', ['red', 'yellow'])
|
|
300
|
+
color_columns = ColorColumns(red=red_lines, yellow=yellow_lines)
|
|
301
|
+
|
|
302
|
+
api_batch_start = api_batch.start # result_df的input起始全局索引
|
|
303
|
+
api_batch_params_end_index = api_batch.params_end_index # result_df的params结束全局索引 + 1
|
|
304
|
+
api_batch_output_end_index = api_batch.output_end_index # result_df的output结束全局索引 + 1
|
|
305
|
+
api_batch_params_slice_index_local = api_batch_params_end_index - api_batch_start # result的params结束局部切片索引
|
|
306
|
+
api_batch_output_slice_index_local = api_batch_output_end_index - api_batch_start # result的output结束局部切片索引
|
|
307
|
+
|
|
308
|
+
# 对单行API的输入或输出进行误差判断
|
|
309
|
+
for i, line in enumerate(result):
|
|
310
|
+
index = api_batch_start + i
|
|
311
|
+
line_info = LineInfo(line_data=line, num_pointer=index)
|
|
312
|
+
for rule in HighlightRules.basic_rules.values():
|
|
313
|
+
rule.apply(line_info, color_columns, self.mode_config.dump_mode)
|
|
314
|
+
|
|
315
|
+
# 对API的输出与输入比较,进行误差判断
|
|
316
|
+
for n, api_out in enumerate(result[api_batch_params_slice_index_local: api_batch_output_slice_index_local]):
|
|
317
|
+
index = api_batch_start + api_batch_params_slice_index_local + n
|
|
318
|
+
# 单行检查只有溢出检查(红色),如果已经溢出,不进一步检查
|
|
319
|
+
if index in red_lines:
|
|
320
|
+
continue
|
|
321
|
+
if not self.check_indices_numeric(api_out, [npu_max_index, bench_max_index, max_diff_index]):
|
|
322
|
+
continue
|
|
358
323
|
|
|
359
|
-
|
|
360
|
-
|
|
324
|
+
# input/parameters的比较检查, 这里api_in包括input、parameters
|
|
325
|
+
for api_in in result[0: api_batch_params_slice_index_local]:
|
|
326
|
+
if not self.check_indices_numeric(api_in, [npu_max_index, bench_max_index, max_diff_index]):
|
|
327
|
+
continue
|
|
328
|
+
api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=index)
|
|
329
|
+
self.apply_comparison_rules(api_info, color_columns)
|
|
330
|
+
|
|
331
|
+
red_lines_num_set = {x[0] for x in red_lines}
|
|
332
|
+
yellow_lines_num_set = {x[0] for x in yellow_lines}
|
|
333
|
+
highlight_dict.get('red_rows', set()).update(red_lines_num_set)
|
|
334
|
+
highlight_dict.get('yellow_rows', set()).update(yellow_lines_num_set - red_lines_num_set)
|
|
335
|
+
highlight_dict.get('red_lines', []).extend(red_lines)
|
|
336
|
+
highlight_dict.get('yellow_lines', []).extend(yellow_lines)
|
|
337
|
+
|
|
338
|
+
def apply_comparison_rules(self, api_info, color_columns):
|
|
339
|
+
"""output与input/params的比较"""
|
|
340
|
+
if self.mode_config.dump_mode == Const.SUMMARY:
|
|
341
|
+
for rule in HighlightRules.summary_compare_rules.values():
|
|
342
|
+
rule.apply(api_info, color_columns, self.mode_config.dump_mode)
|
|
343
|
+
else:
|
|
344
|
+
for rule in HighlightRules.compare_rules.values():
|
|
345
|
+
rule.apply(api_info, color_columns, self.mode_config.dump_mode)
|
|
361
346
|
|
|
362
|
-
|
|
347
|
+
def highlight_rows_xlsx(self, result_df, highlight_dict, file_path):
|
|
348
|
+
"""Write and highlight results in Excel"""
|
|
363
349
|
|
|
364
|
-
|
|
350
|
+
self.update_highlight_err_msg(result_df, highlight_dict) # add highlight err_msg
|
|
365
351
|
|
|
366
|
-
|
|
367
|
-
ws.
|
|
352
|
+
wb = openpyxl.Workbook()
|
|
353
|
+
ws = wb.active
|
|
368
354
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
col_len = len(result_df.columns)
|
|
372
|
-
red_fill = PatternFill(
|
|
373
|
-
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
374
|
-
)
|
|
375
|
-
yellow_fill = PatternFill(
|
|
376
|
-
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
377
|
-
)
|
|
378
|
-
for i in highlight_dict.get("red_rows", []):
|
|
379
|
-
for j in range(1, col_len + 1):
|
|
380
|
-
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
381
|
-
for i in highlight_dict.get("yellow_rows", []):
|
|
382
|
-
for j in range(1, col_len + 1):
|
|
383
|
-
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
355
|
+
# write header
|
|
356
|
+
logger.info('Initializing Excel file.')
|
|
384
357
|
|
|
385
|
-
|
|
386
|
-
save_workbook(wb, file_path)
|
|
358
|
+
self.handle_multi_process_malicious_value_check(self.df_malicious_value_check, result_df)
|
|
387
359
|
|
|
360
|
+
result_df_convert = result_df.applymap(self.compare_result_df_convert)
|
|
388
361
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
return
|
|
362
|
+
for row in dataframe_to_rows(result_df_convert, index=False, header=True):
|
|
363
|
+
ws.append(row)
|
|
392
364
|
|
|
393
|
-
|
|
394
|
-
|
|
365
|
+
# 对可疑数据标色
|
|
366
|
+
logger.info('Coloring Excel in progress.')
|
|
367
|
+
col_len = len(result_df.columns)
|
|
368
|
+
red_fill = PatternFill(
|
|
369
|
+
start_color=CompareConst.RED, end_color=CompareConst.RED, fill_type="solid"
|
|
370
|
+
)
|
|
371
|
+
yellow_fill = PatternFill(
|
|
372
|
+
start_color=CompareConst.YELLOW, end_color=CompareConst.YELLOW, fill_type="solid",
|
|
373
|
+
)
|
|
374
|
+
for i in highlight_dict.get("red_rows", []):
|
|
375
|
+
for j in range(1, col_len + 1):
|
|
376
|
+
ws.cell(row=i + 2, column=j).fill = red_fill # 2因为ws.cell中的row或column需要>=1,数据从第2行开始
|
|
377
|
+
for i in highlight_dict.get("yellow_rows", []):
|
|
378
|
+
for j in range(1, col_len + 1):
|
|
379
|
+
ws.cell(row=i + 2, column=j).fill = yellow_fill
|
|
395
380
|
|
|
396
|
-
|
|
397
|
-
|
|
381
|
+
logger.info('Saving Excel file to disk: %s' % file_path)
|
|
382
|
+
save_workbook(wb, file_path)
|
|
398
383
|
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
for line_index, messages in lines:
|
|
403
|
-
if color == 'yellow' and line_index in red_lines_num_set:
|
|
404
|
-
continue # 如果是 yellow 行,且已被 red 行覆盖,跳过
|
|
384
|
+
def handle_multi_process_malicious_value_check(self, func, result_df):
|
|
385
|
+
result_total_nums = len(result_df)
|
|
386
|
+
process_num = int((multiprocessing.cpu_count() + 1) / 2)
|
|
405
387
|
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
388
|
+
if result_total_nums <= process_num:
|
|
389
|
+
process_num = 1
|
|
390
|
+
chunks = [result_df]
|
|
391
|
+
else:
|
|
392
|
+
chunk_size = result_total_nums // process_num
|
|
393
|
+
chunks = [result_df.iloc[i: i + chunk_size] for i in range(0, result_total_nums, chunk_size)]
|
|
411
394
|
|
|
412
|
-
|
|
413
|
-
red_lines_num_set.add(line_index)
|
|
395
|
+
pool = multiprocessing.Pool(process_num)
|
|
414
396
|
|
|
415
|
-
|
|
397
|
+
def err_call(args):
|
|
398
|
+
logger.error("Multiprocessing malicious value check failed! Reason: {}".format(args))
|
|
399
|
+
try:
|
|
400
|
+
pool.close()
|
|
401
|
+
except OSError:
|
|
402
|
+
logger.error("Pool terminate failed")
|
|
403
|
+
|
|
404
|
+
result_df_columns = result_df.columns.tolist()
|
|
405
|
+
for column in result_df_columns:
|
|
406
|
+
self.value_check(column)
|
|
407
|
+
for df_chunk in chunks:
|
|
408
|
+
pool.apply_async(func, args=(df_chunk, result_df_columns,), error_callback=err_call)
|
|
409
|
+
|
|
410
|
+
pool.close()
|
|
411
|
+
pool.join()
|
|
412
|
+
|
|
413
|
+
def df_malicious_value_check(self, df_chunk, result_df_columns):
|
|
414
|
+
for row in df_chunk.itertuples(index=False):
|
|
415
|
+
api_name = row[0]
|
|
416
|
+
for i, value in enumerate(row):
|
|
417
|
+
self.value_check(value, api_name, i, result_df_columns)
|
|
@@ -23,7 +23,7 @@ from msprobe.core.common.utils import (add_time_with_yaml,
|
|
|
23
23
|
get_stack_construct_by_dump_json_path)
|
|
24
24
|
from msprobe.core.compare.layer_mapping.data_scope_parser import get_dump_data_items
|
|
25
25
|
from msprobe.core.compare.utils import read_op, reorder_op_name_list
|
|
26
|
-
|
|
26
|
+
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
class LayerTrie:
|
|
@@ -71,6 +71,7 @@ class LayerTrie:
|
|
|
71
71
|
file_path = os.path.join(os.path.realpath(output_path), file_name)
|
|
72
72
|
save_yaml(file_path, result)
|
|
73
73
|
|
|
74
|
+
@recursion_depth_decorator("LayerMapping: LayerTrie.convert_to_dict", max_depth=100)
|
|
74
75
|
def convert_to_dict(self, node):
|
|
75
76
|
result = {}
|
|
76
77
|
result["data_item"] = {st: [dt.data_name for dt in dts] for st, dts in node.data_items.items()}
|
|
@@ -163,6 +164,8 @@ def preprocess_layer_mapping(mapping):
|
|
|
163
164
|
for key, value in name_map.items():
|
|
164
165
|
key_list = key.split('.')
|
|
165
166
|
prefix = key_list[0] # 取前缀
|
|
167
|
+
value_list = value.split('(')
|
|
168
|
+
value = value_list[0] # 取前缀
|
|
166
169
|
key_len = len(key_list)
|
|
167
170
|
if prefix not in final_mapping[type_name]:
|
|
168
171
|
final_mapping[type_name][prefix] = []
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (c) 2024-
|
|
1
|
+
# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
|
|
2
2
|
# All rights reserved.
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -21,7 +21,8 @@ from functools import partial
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from tqdm import tqdm
|
|
23
23
|
|
|
24
|
-
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory
|
|
24
|
+
from msprobe.core.common.file_utils import load_yaml, logger, FileChecker, save_excel, read_xlsx, create_directory, \
|
|
25
|
+
remove_path
|
|
25
26
|
from msprobe.core.common.const import FileCheckConst, Const, CompareConst
|
|
26
27
|
from msprobe.core.common.utils import CompareException, add_time_with_xlsx
|
|
27
28
|
from msprobe.core.compare.utils import table_value_is_valid
|
|
@@ -32,8 +33,8 @@ def check_compare_result_name(file_name):
|
|
|
32
33
|
"""
|
|
33
34
|
check whether the compare result name is as expected
|
|
34
35
|
"""
|
|
35
|
-
single_rank_pattern = r"^
|
|
36
|
-
multi_ranks_pattern = r"^compare_result_rank(\d+)
|
|
36
|
+
single_rank_pattern = r"^compare_result_(rank|rank-rank)_\d{14}\.xlsx$"
|
|
37
|
+
multi_ranks_pattern = r"^compare_result_rank(\d+)(?:-rank\1)?_\d{14}\.xlsx$"
|
|
37
38
|
if re.match(multi_ranks_pattern, file_name):
|
|
38
39
|
return True
|
|
39
40
|
if re.match(single_rank_pattern, file_name):
|
|
@@ -47,7 +48,7 @@ def reorder_path(compare_result_path_list):
|
|
|
47
48
|
"""
|
|
48
49
|
reorder compare results by rank num
|
|
49
50
|
"""
|
|
50
|
-
rank_pattern = r"compare_result_rank(\d+)
|
|
51
|
+
rank_pattern = r"compare_result_rank(\d+)"
|
|
51
52
|
reorder_path_list = sorted(
|
|
52
53
|
compare_result_path_list,
|
|
53
54
|
key=lambda path: int(re.search(rank_pattern, os.path.basename(path)).group(1))
|
|
@@ -63,6 +64,7 @@ def get_result_path(input_dir):
|
|
|
63
64
|
for f in os.listdir(input_dir) if f.endswith(FileCheckConst.XLSX_SUFFIX)]
|
|
64
65
|
filt_compare_result_path_list = []
|
|
65
66
|
for file_path in compare_result_path_list:
|
|
67
|
+
FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
|
|
66
68
|
file_name = os.path.basename(file_path)
|
|
67
69
|
if check_compare_result_name(file_name):
|
|
68
70
|
compare_result_path_checker = FileChecker(file_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE)
|
|
@@ -236,7 +238,7 @@ def handle_multi_process(func, func_args, lock):
|
|
|
236
238
|
def err_call(args):
|
|
237
239
|
logger.error('Multiprocess merge result failed! Reason: {}'.format(args))
|
|
238
240
|
try:
|
|
239
|
-
pool.
|
|
241
|
+
pool.close()
|
|
240
242
|
except OSError:
|
|
241
243
|
logger.error("Pool terminate failed")
|
|
242
244
|
|
|
@@ -329,6 +331,10 @@ def generate_merge_result(all_compare_index_dict_list, all_rank_num_list, all_co
|
|
|
329
331
|
for i, df in enumerate(merge_df_list):
|
|
330
332
|
# merge_df_list中df与compare_index_list中compare_index一一对应
|
|
331
333
|
final_result_df_list.append((df, compare_index_list[i]))
|
|
334
|
+
|
|
335
|
+
if os.path.exists(output_path):
|
|
336
|
+
logger.warning(f"{output_path} will be deleted.")
|
|
337
|
+
remove_path(output_path)
|
|
332
338
|
save_excel(output_path, final_result_df_list)
|
|
333
339
|
logger.info(f"The compare results of the multi-ranks are merged and saved in: {output_path}.")
|
|
334
340
|
|